{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Reframing Design Pattern\n",
    "\n",
    "The *Reframing design pattern* refers to changing the representation of the output of a machine learning problem. For example, we could take something that is intuitively a regression problem and instead pose it as a classification problem (and vice versa)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's look at the natality dataset. Notice that for a given set of inputs, the weight_pounds (the label) can take many different values. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import seaborn as sns\n",
    "from google.cloud import bigquery\n",
    "\n",
    "import matplotlib as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "bq = bigquery.Client()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"\"\"\n",
    "SELECT\n",
    "  weight_pounds,\n",
    "  is_male,\n",
    "  gestation_weeks,\n",
    "  mother_age,\n",
    "  plurality,\n",
    "  mother_race\n",
    "FROM\n",
    "  `bigquery-public-data.samples.natality`\n",
    "WHERE\n",
    "  weight_pounds IS NOT NULL\n",
    "  AND is_male = true\n",
    "  AND gestation_weeks = 38\n",
    "  AND mother_age = 28\n",
    "  AND mother_race = 1\n",
    "  AND plurality = 1\n",
    "  AND RAND() < 0.01\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>weight_pounds</th>\n",
       "      <th>is_male</th>\n",
       "      <th>gestation_weeks</th>\n",
       "      <th>mother_age</th>\n",
       "      <th>plurality</th>\n",
       "      <th>mother_race</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>7.187070</td>\n",
       "      <td>True</td>\n",
       "      <td>38</td>\n",
       "      <td>28</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>7.312733</td>\n",
       "      <td>True</td>\n",
       "      <td>38</td>\n",
       "      <td>28</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>6.801261</td>\n",
       "      <td>True</td>\n",
       "      <td>38</td>\n",
       "      <td>28</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>8.000575</td>\n",
       "      <td>True</td>\n",
       "      <td>38</td>\n",
       "      <td>28</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>8.811877</td>\n",
       "      <td>True</td>\n",
       "      <td>38</td>\n",
       "      <td>28</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   weight_pounds  is_male  gestation_weeks  mother_age  plurality  mother_race\n",
       "0       7.187070     True               38          28          1            1\n",
       "1       7.312733     True               38          28          1            1\n",
       "2       6.801261     True               38          28          1            1\n",
       "3       8.000575     True               38          28          1            1\n",
       "4       8.811877     True               38          28          1            1"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = bq.query(query).to_dataframe()\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAEXCAYAAABWNASkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xl8nXWZ///XdZbs+94lbdKVLmAppewgCooDggx+ZwDlK44jwzjMOIqjzqiMMt/xy6hfHUV+OgwyuFBQULEiqEjZCynd9yVtmqVp2ux7muVcvz/uOyWEpNnOyZ1zzvV8PPLoWe5z39dJk3c+57rv+3OLqmKMMSa2+LwuwBhjTPhZuBtjTAyycDfGmBhk4W6MMTHIwt0YY2KQhbsxxsQgC/c4JCI/FJGvhGld80SkQ0T87v0XReSvw7Fud33PisjHwrW+CWz3/4hIg4jUjfDcu0WkZpLrLRERFZHA1KucOhG5TEQOjHPZSb9vM/1mxA+YCR8ROQoUAv3AALAX+AnwoKqGAFT1zgms669V9U+jLaOqVUDa1Ko+vb2vAotU9aND1v+BcKx7gnXMA+4G5qvqyene/nRS1VeApeFYl4g8AtSo6pfDsT4zNTZyj00fVNV0YD5wH/AF4Efh3shMGX1GwDygMdaD3cQ2C/cYpqqtqroe+EvgYyKyEpwRloj8H/d2nog8LSItItIkIq+IiE9EfooTcr912y6fH9JS+ISIVAEbRmkzLBSRTSLSJiK/EZEcd1vv+FgvIkdF5CoRuQb4F+Av3e3tcJ8/3eZx6/qyiFSKyEkR+YmIZLrPDdbxMRGpclsqXxrteyMime7r6931fdld/1XAc8Bst45HzrCOf3G3c1REPjLk8WtFZJv7/qvdTyTD/ZWI1IrIcRH5nPu6IhHpEpHcIeta7dYYHLbtJBHpFpE89/6XRKRfRDLc+/8mIv/p3k4UkW+535cTblsueaT/E3d720SkXUSeEJGfD/6sDFnmbvf7f1xEPu4+dgfwEeDz7vftt6N938z0sHCPA6q6CagBLhvh6bvd5/Jx2jn/4rxEbwOqcD4FpKnqN4a85gpgGfD+UTb5v4G/AmbhtIe+N44afw98Hfi5u713jbDY7e7XlcACnHbQ94ctcylOm+G9wD0ismyUTd4PZLrrucKt+eNuC+oDQK1bx+2jvL4IyAPmAB8DHhSRwfZGp7u+LOBa4G9F5EPDXn8lsBh4H/AFEblKVeuAF4G/GLLcbcDjqto39MWq2gO86daO+28lcMmQ+y+5t+8DlgCrgEVuzfcMf0MikgD8GngEyAEeA24c4X1nuuv4BPCAiGSr6oPAo8A33O/bB4ev30wvC/f4UYvzCztcH04Iz1fVPlV9RceecOirqtqpqt2jPP9TVd2tqp3AV4C/EHeH6xR9BPi2qh5R1Q7gn4Gbh31q+JqqdqvqDmAH8I4/Em4tNwP/rKrtqnoU+H84QToRX1HVU6r6EvA73FBW1RdVdZeqhlR1J05IXjHstV9zv4e7gP8BbnEf/zHw0SF13gL8dJTtvwRc4b7/c3D+iF4hIknA+cDLIiLAHcBnVLVJVdtx/ojePML6LsTZD/c992fhV8CmYcv0Afe6zz8DdBCmnr0JLwv3+DEHaBrh8W8C5cAfReSIiHxxHOuqnsDzlUAQZ5Q7VbPd9Q1ddwDnE8egoUe3dDHyzt48t6bh65ozgVqa3T9eQ18/G0BELhCRF9x2SitwJ+98/8O/R7Pd278BlotIKXA10Op+8hrJS8C7gdXALpx20hU4IV2uqo04n8hSgC1u660F+L37+HCzgWPD/rgP/79uVNX+IfdH+x4bj1m4xwEROR8nuF4d/pw7cr1bVRcA1wOfFZH3Dj49yirHGtkXD7k9D2e014DTrkgZUpeft4fMWOutxdlJPHTd/cCJMV43XINb0/B1HZvAOrJFJHXY62vd2+uA9UCxqmYC
PwRk2OuHf49q4XS75Rc4o/fbGH3UDrARZ9R8I/CSqu511/VnvNWSaQC6gRWqmuV+ZarqSIF8HJjjjvZHqnMsNsXsDGLhHsNEJENErgMeB37mtgCGL3OdiCxyf6FbcQ6fDLlPn8DpSU/UR0VkuYikAPcCT6rqAHAQSHJ3OAaBLwOJQ153AigRkdF+Lh8DPiMipSKSxls9+v5Rlh+RW8svgH8XkXQRmQ98FvjZRNYDfE1EEkTkMuA64An38XSgSVV7RGQtcOsIr/2KiKSIyArg48DPhzz3E5x9C9dzhnBX1S5gC/B3vBXmG3E+KbzkLhMC/hv4jogUAIjIHBEZaX/J6zj//3eJSEBEbgDWjvE9GGqyPy8mAizcY9NvRaQd5yP1l4Bv4wTISBYDf8Lpnb4O/H+q+oL73P8Fvux+nP/cBLb/U5ydcnVAEvAP4By9A3wKeAhnlNyJszN30GA4NorI1hHW+7C77peBCqAH+PsJ1DXU37vbP4LziWadu/7xqgOacUbcjwJ3qup+97lPAfe6/wf34PwhGe4lnHbY88C3VPWPg0+o6ms4f2C3qmrlCK8dvp4gb/XGX8L54/LykGW+4G7rDRFpw/n/fkefXFV7gT/H2VHagvPp4Wng1Bg1DPoRTkupRUSeGudrTISIXazDmJlHRDYA61T1IY/rKAN+qKr/42UdZuJs5G7MDOPuI1nN21s107XtK9zj7QPiTPtwDs4OWBNlYvUMQ2Oikoj8GPgQ8Gn3sMXpthSnjZSK07L6sKoe96AOM0XWljHGmBhkbRljjIlBnrVl8vLytKSkxKvNG2NMVNqyZUuDqo50EtrbeBbuJSUlbN682avNG2NMVBKRsQ6PBawtY4wxMcnC3RhjYpCFuzHGxCALd2OMiUEW7sYYE4Ms3I0xJgZZuBtjTAyycDfGmBhk4W6MMTHIZoU0ZgrWlVWNe9lbL5gXwUqMeTsbuRsTBq3dfWw83EBrd5/XpRgD2MjdmCnr7Q/x441HqWvr4Xc7j7O4MI0LF+RyVlGG16WZOGYjd2OmQFX55dYaTrT1cNPqubx7aT51rT385PVKDtd3eF2eiWMW7sZMwcuHGth1rJX3rSjivPnZXL28iLvft5TM5CDP7T2BXQzHeMXC3ZhJ2ljewB/31HH2nEwuX5x3+vGg38eVSwuoauri4AkbvRtvWLgbM0nf23CIzOQgN62ei4i87bnz5meTk5rAc/vqbPRuPGHhbswkHK7v4I0jTawtzSEh8M5fI79PeM9ZBdS29LD3eJsHFZp4Z+FuzCQ8VlZFwCecNz971GVWFWeRn5bIc3tPELLRu5lmdiikMeM0eMJS30CIR8uqOGtWBulJwVGX94kzev/55mrKT1rv3UwvG7kbM0G7j7XS3TfA2pKcMZddPjuDhICPPbWt01CZMW+xcDdmgjZVNJGbmsCC/NQxlw36fSwtTGdvbRsDIWvNmOlj4W7MBJxo66GyqYu1pTn4hh0hM5qVczLp7B3gzaNNEa7OmLdYuBszAZuPNuH3Cavnjb4jdbglhWkEfMLvd9dFsDJj3s7C3ZhxCqmyu7aNJYXppCaO/1iExICfJYXp/H53HSFrzZhpYuFuzDjVNHfT2t3HytkTnxBsxewM6tp62F7TEoHKjHknC3djxmnPsVb8IpOa7fGsogyCfmvNmOkzrnAXkWtE5ICIlIvIF8+w3E0ioiKyJnwlGuM9VWV3bSuLCtJITvBP+PXJCX4uXpjHs7uP23QEZlqM2TgUET/wAHA1UAO8KSLrVXXvsOXSgU8DZZEo1Bgv7alto7mrjyuXFkx6HTmpCbx0sJv/98eDzM5KPv24XaHJRMJ4Ru5rgXJVPaKqvcDjwA0jLPdvwH8APWGsz5gZ4Zldx/EJLJ81+QtwnFWUDsAhO1vVTIPxhPscoHrI/Rr3sdNEZDVQrKq/O9OKROQOEdksIpvr6+snXKwxXlBVnt1dx4K8NFImcJTMcOlJQQozEjls4W6mwZR3qIqI
D/g2cPdYy6rqg6q6RlXX5OfnT3XTxkyLAyfaqWjoZMWcqV82b1F+GkcbO+kbCIWhMmNGN55wPwYUD7k/131sUDqwEnhRRI4CFwLrbaeqiRXP7qpDptiSGbSwII3+kFLZ2BWGyowZ3XjC/U1gsYiUikgCcDOwfvBJVW1V1TxVLVHVEuAN4HpV3RyRio2ZZn/ce4Lz5+eccQbI8SrNTcUn2PVVTcSNGe6q2g/cBfwB2Af8QlX3iMi9InJ9pAs0xku1Ld3sO97Ge5dN/iiZoRKDfopzUmwKYBNx49o7pKrPAM8Me+yeUZZ999TLMsY7g/O2A7xxpBGA7r6BsIzcwem7b9h/kq7eflIS7JIKJjLsDFVjzuBAXTs5qQnkpyWGbZ2LCtJQ4Eh9Z9jWacxwFu7GjKK3P8Th+g6WFaW/4wLYUzE3O4WEgI9y67ubCLJwN2YU5Sc76A8pSycxl8yZ+H3CgrxUO97dRJSFuzGj2F/XRmLAR0leStjXvTA/jcbOXpo7e8O+bmPAwt2YEYVUOVDXzuLCdAK+8P+aLCpIA+yQSBM5Fu7GjKC2pZv2U/0sc+eDCbeC9ERSE/xUNNhOVRMZFu7GjGB/XTsCLCmMTLiLCKV5qVQ0dNoUwCYiLNyNGcGBunaKc1ImdDm9iSrNT6Olu4+a5u6IbcPELwt3Y4bpOtVPbUs3i92+eKSU5qUCb50oZUw4WbgbM8zhhk6Ut3Z6RkpBeiIpCX7eONIU0e2Y+GThbsww5SfbSQz4mJsd/kMgh/K5fXcbuZtIsHA3ZghVpfxkBwvy0/D7wndW6mhK81I51tJNdZNNAWzCy8LdmCEqG7to7uqLeEtm0II8ZztlFdaaMeFl4W7MEK+WNwCwOH96wr0gI5GslKC1ZkzYWbgbM8SrhxrISg6Sm5YwLdvziXBBaQ5lFRbuJrws3I1xDYSUjYcbWFSQFtZZIMdy4YJcqpu6OdZix7ub8LErBRjj2lnTQltPPwunqd8+qMmdPOw7zx1k9bxsAG69YN601mBij43cjXG9esjpty+cpn77oMKMJJKDNs+MCS8Ld2Ncr5Y3sGJ2BmkRnHJgJD4RStx5ZowJFwt3Y4CevgG2Vbdw0YJcT7ZfmpdKU2cvrd19nmzfxB4Ld2OAnTWt9PaHWFua48n2F7jzzFQ02PzuJjws3I0B3jzqnER0fok34V6UmURS0GetGRM2Fu7G4JwhuqQwjezU6Tm+fTifCCW5qRypt3A34WHhbuJe/0CIrZXNnrVkBpXmpdLY2Uub9d1NGFi4m7i373g7Haf6PWvJDCo93Xe30buZOgt3E/c2uf12r0fuszKTSQxY392Eh52hauLaurIqfrmlhuyUIC/sr/e0Fr/P6btbuJtwsJG7iWuqytHGztMtEa+V5qVS33GKk+09XpdiopyFu4lr9e2n6OodoCR35oQ7wCab391MkYW7iWtHG50rIJXMkJH77KxkEgI+m9/dTJmFu4lrRxs7SUsMkOvR8e3DOX33FLtotpkyC3cT1442dlKSmzKt87ePZUFeGuUnO6hvP+V1KSaKWbibuFXX2kNLVx/zZ0i/fdBg391aM2YqLNxN3Npa1QzAvJwUjyt5u9lZyaQlBizczZRYuJu4tbWymYBPmJWV5HUpb+P3CeeXZPO6hbuZAgt3E7e2VjUzOyuZgG/m/RpcuCCXI/WdnGyz493N5My8n2pjpsGp/gF2H2tj/gxryQy60L1oyBt2vLuZJAt3E5f21LbROxCieIaG+4rZGaQnBnj9sLVmzOSMK9xF5BoROSAi5SLyxRGev1NEdonIdhF5VUSWh79UY8Jna6W7MzV3ZoZ7wO/j/NIcyqzvbiZpzHAXET/wAPABYDlwywjhvU5Vz1bVVcA3gG+HvVJjwmhrVTNzspLJSAp6XcqoLlqQy5GGTk5Y391MwnhG7muBclU9oqq9wOPADUMXUNW2IXdTAQ1ficaE39bKFlbPz/a6jDM63Xe30buZhPGE+xygesj9
GvextxGRvxORwzgj938YaUUicoeIbBaRzfX13k6vauJXbUs3dW09rJ6X5XUpZ7R8dgbpSdZ3N5MTth2qqvqAqi4EvgB8eZRlHlTVNaq6Jj8/P1ybNmZCBk9eOm+Gj9z9PuGC0lw73t1MynjC/RhQPOT+XPex0TwOfGgqRRkTSVsqm0kK+lg2K8PrUsZ08cJcKhu7ONbS7XUpJsqMJ9zfBBaLSKmIJAA3A+uHLiAii4fcvRY4FL4SjQmvrVUtnDMni6B/5h8JfNFCp+9urRkzUWNeZk9V+0XkLuAPgB94WFX3iMi9wGZVXQ/cJSJXAX1AM/CxSBZtzGT19A2wt7aVv7qk1OtSzmhdWRUAIVVSEvw8+kYlHz5vrsdVmWgyrmuoquozwDPDHrtnyO1Ph7kuYyJi7/E2+gaUc2f4ztRBPhEW5KdxpKETVZ1RUxObmW3mfy41Joy2V7UAsKp4Zu9MHWpBXiqt3X1UuleNMmY8LNxNXNle3UJhRiJFmTNrJsgzWZifBsBG67ubCbBwN3FlR00Lq4qjoyUzKC8tgYykABsPN3hdioki4+q5GxML/vvlI1Q2dnFWUcbpHZbRQNy++xtHGq3vbsbNRu4mbtQ0Oz3r4uxkjyuZuAV5qTR09HLoZIfXpZgoYeFu4kZ1czcCzInCcB/su9vx7ma8LNxN3Khu6qIgI5HEgN/rUiYsOzWB4pxkXiu3vrsZHwt3ExdUlZrmboqzZ+b87eNxycI83jjSyEDIJl01Y7NwN3HhaGMX3X0DUR3uFy/Ko62nn93HWr0uxUQBO1rGxIXt1c5MkHNzoq/fPqiu1bloxwMvlPPupQWnH7/1gnlelWRmMBu5m7iwvaqFBL+PwozoOXlpuLTEAEUZSRyutyNmzNgs3E1c2F7TyuysZHxRfoz4wvxUKhu76BsIeV2KmeEs3E3MO9U/wL7aNoqjuCUzaGFBGv0htXlmzJgs3E3M23e8nd6BUFTvTB1UmpuKT7DWjBmThbuJeTuqnZkg50bhyUvDJQb9FGenWLibMVm4m5i3o6aFvLREMpODXpcSFgsL0jjW3E1374DXpZgZzMLdxLwd1S2sKs6MmQm3FuanoUBFg43ezegs3E1Ma+vp43B9J++aG13T/J5JcU4yQb9Qbq0ZcwYW7iam7apxzuZ8V5TN4X4mAZ+P0rxUyk92el2KmcEs3E1M2+7uTD1nbqbHlYTXooJ0GjpO0dLV63UpZoaycDcxbWdNCyW5KWSlJHhdSlgtKnCmAC63+d3NKCzcTUzbUd0aUy2ZQYXpiaQnBeziHWZUFu4mZtW19lDX1hNTO1MHiQiL8tM4XN9ByKYANiOwcDcxa0eN02+PxZE7OK2Zrt4B9tS2eV2KmYEs3E3M2lHdQsAnrJid4XUpETHYd3+lvN7jSsxMZOFuYtbOmlaWFqWTFIy+y+qNR3pSkKKMJF49ZJfeM+9k4W5iUiik7KhpidmWzKBFBWlsPtpsUxGYd7BwNzHpSEMn7T39rIrBnalDLS5Io3cgRFlFo9elmBnGwt3EpMGTl1bNi+1wL8lLJSHgs9aMeQcLdxOTtlU1k54YYFF+mtelRFTQ72NtSQ4vH7KdqubtLNxNTNpe3cI5xZn4fLExE+SZXLY4j4MnOk5fQNsYsHA3Mai7d4D9de2cW5ztdSnT4vIl+QA2ejdvY+FuYs6uY60MhJRVMX6kzKCzitIpSE/k5YMW7uYtAa8LMCbcHn61AnCOmFlXVuVxNZEnIly2OJ/n959gIKT446AVZcZmI3cTc6qbu8hOCZKWGD9jl8uX5NHS1ceuY61el2JmCAt3E3NqmrspzknxuoxpdemiPESw1ow5zcLdxJS61h5au/sozo6vcM9NS2Tl7EwLd3PauMJdRK4RkQMiUi4iXxzh+c+KyF4R2Skiz4vI/PCXaszYtlc3AzAvzkbu4LRmtlW30NbT53UpZgYYM9xFxA88AHwAWA7cIiLLhy22DVijqucATwLfCHehxozH
tqoW/D5hVmaS16VMu8sX5zMQUjaW29mqZnwj97VAuaoeUdVe4HHghqELqOoLqtrl3n0DmBveMo0Zn23VLczOTCLgj7+O4+r52aQlBnjpoIW7Gd+hkHOA6iH3a4ALzrD8J4Bnp1KUMZPRPxBiV01rzM8nM9zQwz2Lc1J4dvdxvn7jSkTskMh4FtbhjYh8FFgDfHOU5+8Qkc0isrm+3nb8mPDaX9dOd99A3O1MHWpJYRotXX0crrdrq8a78YT7MaB4yP257mNvIyJXAV8CrlfVUyOtSFUfVNU1qromPz9/MvUaM6otlc7O1Pm58RvuSwvTAXhhvw2e4t14wv1NYLGIlIpIAnAzsH7oAiJyLvBfOMF+MvxlGjO2zZXNFGUkkZUc9LoUz2SlJFCQnsiLB+3XMN6NGe6q2g/cBfwB2Af8QlX3iMi9InK9u9g3gTTgCRHZLiLrR1mdMRGz5WgT55Vkx32veWlhOpsqmug81e91KcZD4zo/W1WfAZ4Z9tg9Q25fFea6jJmQ2pZualt7+OT8+JgJ8kyWFKXzSnkDr5U38L4VRV6XYzwSf8eLmZi02e23r5mf43El3pufm0Jqgp8X7WzVuGbhbmLClqNNpCT4WTYr3etSPBfw+bhkUR4vHahHVb0ux3jEwt3EhM2VzawqzorLk5dGcuVZBRxr6ebQSTskMl7Fz5yoJmZ1nOpn3/E27rpykdelzBgtXc78Mt957iCXLX7rsONbL5jnVUlmmtkwx0S97VUthBTOK7F++6DM5CCFGYkcONHudSnGIzZyN1FtXVkVz+87gQCHT3ZwrLnb65JmjKWF6bxa3kBP3wBJQb/X5ZhpZiN3E/Uqm7oozEiyABtm2awMQoqN3uOUhbuJagMhpaqpK66nHBhNcU4KqYkB9h1v87oU4wELdxPVTrT10NsfsnAfgU+EZUXpHKhrpz8U8rocM80s3E1UO9rYCcD83FSPK5mZls3K4FR/iIr6Tq9LMdPMwt1EtYqGTrJSgmSnJHhdyoy0qCCNoF/Ya62ZuGPhbqKWqlLR0EmpjdpHFfT7WFyQzv66djtbNc5YuJuoVX6yg67eAUrzLNzPZPmsDFq7+6ht6fG6FDONLNxN1CqraAKwcB/D0qJ0BKw1E2cs3E3UKqtoIiMpQE6q9dvPJDUxwPzcVDskMs5YuJuopKpsqmikJC817i/OMR4rZmdQ19ZDRYMdNRMvLNxNVKps7OJE2ylryYzTyjmZADy9o9bjSsx0sXA3UWmT228vsSNlxiUzOcj83BSe3nnc61LMNLFwN1GprKKJnFTnYtBmfM6Zm8WBE+0ctLlm4oKFu4lKZRWNrC3JsX77BKycnYFPrDUTLyzcTdQ51tJNTXM3Fyyw+dsnIj0pyEULc3l653E7oSkOWLibqFN2pBGAtaUW7hN13TmzOdLQyZ5aOywy1lm4m6iz8XAj2SlBlhVleF1K1LlmRREBn9iO1Thg4W6iiqry+uFGLlqYi89n/faJyk5N4NLFeTy9s9ZaMzHOwt1ElcrGLo61dHPRwjyvS4laHzxnNjXN3Wytava6FBNBFu4mqmw87PTbL16Y63El0WldWRWt3X0E/cI3fn+AdWVVXpdkIsTC3USV1w43UJSRxAI7M3XSkoJ+ls3KYNexVrtCUwyzcDdRIxRS3jjcyMULc+349ilaVZxFV+8Ah050eF2KiZCA1wUYMx7ryqqoa+2hsbMXn4i1E6ZocUE6KQl+tle3eF2KiRAbuZuocbjeGWUuyLeWzFT5fcI5c7PYd7yN9p4+r8sxEWDhbqLG4foOclMTyLLrpYbFquIs+kPK73fXeV2KiQALdxMVBkLO9VIX5qd5XUrMKM5OJic1gae2H/O6FBMBFu4mKtS2dHOqP2QtmTASEVYVZ7HxcCPHW7u9LseEmYW7iQqHTrYjwAIbuYfVucVZqMJT22ymyFhj4W6iwoG6duZkJ5OWaAd4hVNuWiLnzc/mV1trbDqCGGPhbma8xo5T1DR3
s7Qo3etSYtJNq+dy6GQHu4/ZTJGxxMLdzHgvHqhHgbMKbRbISLj27FkkBHz8cmuN16WYMLJwNzPehgMnSU8MMCsryetSYlJmSpCrlxWyfkctvf02HUGsGFe4i8g1InJARMpF5IsjPH+5iGwVkX4R+XD4yzTxqm8gxMsH61lSlI7PphyImJvOm0NTZy8vHaz3uhQTJmPunRIRP/AAcDVQA7wpIutVde+QxaqA24HPRaJIE7+2VDbT3tPP0kLrt0fKurIqBkJKamKA//zTQerbTwFw6wXzPK7MTMV4Ru5rgXJVPaKqvcDjwA1DF1DVo6q6E7DPdCasNuw/SdAvLC6wQyAjye8TVs3NZP/xdrpO9XtdjgmD8YT7HKB6yP0a97EJE5E7RGSziGyur7ePf2ZsG/af5ILSXBKDfq9LiXmr52czoMo2m0wsJkzrDlVVfVBV16jqmvz8/OnctIlC1U1dlJ/s4MqzCrwuJS7MykxmbnYym4422THvMWA84X4MKB5yf677mDER9fy+EwC8x8J92lxQmkN9+ymONnZ5XYqZovGE+5vAYhEpFZEE4GZgfWTLMgZ+s6OWs4rSKbWrLk2bs+dkkRT0sami0etSzBSNGe6q2g/cBfwB2Af8QlX3iMi9InI9gIicLyI1wP8C/ktE9kSyaBP7jtR3sK2qhRvPndTuHTNJCQEfq4qz2V3bRlNnr9flmCkYV89dVZ9R1SWqulBV/9197B5VXe/eflNV56pqqqrmquqKSBZtYt9T247hE/iQhfu0W1uaw0BI+eUWO2M1mtkZqmbGCYWUX207xiWL8ijMsLNSp1tRRhLzc1JYt6nKdqxGMQt3M+NsrmymprnbWjIeWluaQ0VDJ6+WN3hdipkkC3cz4/x6Ww0pCX7ev6LI61Li1so5mRRmJPL9DeVel2ImycLdzCg9fQM8vfM416woItXmbvdM0O/jjssXUlbRxKaKJq/LMZNg4W5mlA37T9Le08+Nq60l47Vb184jLy2B+zcc8roUMwk2NDIzhqry8KsVZCYHqWzsorqpyuuS4lpygp9PXraA//vsfrZVNXPuvGyvSzITYCN3M2O8fqSRzZXNXL4k36b3nQHWlVWREPCRkuDni7/cxboy+2MbTSzczYyTgyv9AAAPz0lEQVRx//PlFKQnsma+jRBnisSAn0sW5XHgRDs1zTYlQTSxcDczwqaKJl4/0sjfXLGQoN9+LGeSixbkkprg57c7agmF7Lj3aGG/RWZGuH/DIfLSErh1rV0gYqZJCvr5s7NnUd3czaObrDUTLSzcjee2VjXzyqEGPnnZApITbN72mWhVcRYL81P5xrP7OdnW43U5Zhws3I2nBkLK13+3j+yUIB+9cL7X5ZhRiAg3vGsOpwZCfO3pvWO/wHjOwt146uFXK9hc2cxXrltuJy3NcHnpidx15SJ+t/M4z+094XU5ZgwW7sYzh060880/HuDq5YU2j0yU+JsrFrB8Vgaf/cV2jtR3eF2OOQMLd+OJ/oEQdz+xg9QEP1+/8WzEjmuPCokBP/9123nO9AQ/3UJ7T5/XJZlR2Odg44n7N5Szs6aVW9bOs4/4UaY4J4UHbl3NR39Uxmd+vp0Hb1uDz2d/nGcaG7mbaferrTV89/lDnFucxdlzMr0ux0zAurIq1pVVUdHQyQdWFvGnfSe57UdlNu/7DGQjdzMtBk9dP3SynR9vPMqCvFTrs0e5ixbk0tjZy2uHG7n36b3cc91ya6/NIBbuZtrUtnTzaFkVBelJfPTC+QTsTNSoJiJcd/YsBPif146iCv/6QQv4mcLC3UyL2pZuHn6tgpSgn9svLiEpaCcrxQIR4dqzZ7F8VgYPvVrBqf4B/u2GlfaHewawcDcRt62qmYdePUJSwM8nLi0lIznodUkmjESEL127jMSgjwdeOEx9ey/333KunW3sMfvzaiJqU0UTt/1oEykJAT55+QJy0xK9LslEgIjwT+8/i3tvWMHz+09w60Nv0NTZ63VZ
cc3C3UTMhv0n+N8Pl1GYkcgnL1tAdkqC1yWZCBk8iibg83Hr2nnsqmnlph9spKrRpgn2ioW7iYgnt9TwyZ9sYXFBOj//m4vItFZM3FgxO5NPXFpKU2cvf/6D19hV0+p1SXHJwt2Elary4MuH+dwTO7hwQQ6P3XEhedaKiTvzc1P5+MUl9IeUm36wka+u32NXcppmFu4mbAZCytd+u5evP7Ofa8+excO3n0+aTQYWtwoykrjzioXkpiXwk9ePsqmiyeuS4oqFuwmLnr4BPvXoFh7ZeJRPXFrK/becS2LAjpaIdxlJQe64bAGLCtJ4avsx7nt2v13NaZrYsMpM2uDH7I5T/fzsjUqqm7q49uxZfOW65R5XZmaSxKCf2y4s4bc7avnhS4epburim//rHFISLH4iyb67ZkrqWnv4yRtH6ejp55a181g5J9N6q+Yd/D7hhlWzee+yAu77/X4O13fw4G1rmJeb4nVpMcvaMmbS9te18cOXDxMKKXdcvoCVNgmYOQMRIT0pyMcuKqGysYv3/+fLtqM1gizczYSFQsp3/3SIn75eSV5aAn/77kXMzbYRmBmfJYXpfOrdC8lMDvLjjUd5bm8d/QMhr8uKORbuZkKaOnu5/ZE3+c6fDrKqOIs7Lltox7CbCctNS+TOKxayen42Lxyo59b/LuN4a7fXZcUU67mbcdtY3sA/PbmT+vZT/PuNK0GxGQDNpCUEfNy0ei4L81N5anst7/nWS3zo3Dmn5/i/9YJ5HlcY3WzkbsbUcaqfLz+1i1sfKiMh4OOJOy/iIxfMt2A3YbGqOJu73r2InNQEHttUxZNbqunpG/C6rKhnI3czqoGQsn7HMb71h4PUtnRz6aI8rlpWyJ7aNvbUtnldnokheelOm2bD/pO8eOAkRxo6mZ+bwnuXFXpdWtQSry6PtWbNGt28ebMn2zZndqp/gGd31fG95w9xpKGTs4rSuWJJPvNzU70uzcSBqsZOfrntGPXtp3jf8kLu+eBy22E/hIhsUdU1Yy5n4W5Ulbq2HrZWtvCHPXVs2H+SjlP9LC1M5x+vWsz7VxTx+JvVXpdp4kh/KERX7wDf/dMhQqr8+eo5fPySUpYUpntdmufGG+7WlokDqkpLVx+1rd3UtfZQ29pDXWs3x1t7qGnqZn9dG209/QDkpCZw7dmz+MDZRVy+ON+uam88EfD5uPOKEj74rtnc//whfrX1GI9tqubihblcubSA80tzWDE7g6Bd8WlU4wp3EbkG+C7gBx5S1fuGPZ8I/AQ4D2gE/lJVj4a3VDOW/oEQlU1dHDrRzoG6Dg7Xd1DR0MnBE+2c6n/7ccQ+ceb9yEwJctasDIoykpiVmcQ/vX+pXSLNzBhzspK576Zz+Pw1Z/HYpiqe2FzNvz+zD4CkoI9ZmcnkpyXS1dtPQsBP0C8E/T4SAj4S/D4uWZxHZnKQ3NQEslMSyE9PJDc1IS4GLWO2ZUTEDxwErgZqgDeBW1R175BlPgWco6p3isjNwI2q+pdnWu90tGVUlf6Q0tsfon9A6Qs5/4ITbggEfT6CAZ/zQ+HzheU/PRRSegdC9A2E6BtQBkJKSJ0vvwg+nxDwvfVDGPDJuI486ekboLW7j9buPmcE3tLNsZZujjR0cvhkB0fqO+l1TwYRICslSF5aIrlpCeSkJpKZHCQrOUhmcpC0pAA+O9rFRKGrlhXw5tFmtlc3c7y1h/r2U6d/9vuG/N6NJugXCtKTKMp0vzKSyE9PJDslSHZKAulJQRKDPpICfhICzu+mAD4RfCKIONMpBPxCgt834d/jqQpnW2YtUK6qR9wVPw7cAOwdsswNwFfd208C3xcR0Qg09B/bVMUPXjyMoqjifikDqoTUGb32DSh9AyF6B0JMtIKAT0gIOP9hAZ/gd7+Edx7THXL/eIRCevoHqm8gRP8EZ72TwT8yfiHg9+ETTv9A9bvr7u0feb0CZKcmUJCeyAULcihIT6IwI5GC9CQSAjYCN7HnT/tOAlCal0Zp
XtqIy4TU+b051R+iu3eAzt5+Ok8N0N7TR1t3P209fTR39lLZ2Elrd98Z/xiMlwgk+J1PDAH3d9kv8tbvszjLCMLd71vCDavmTHmbZzKecJ8DDN2bVgNcMNoyqtovIq1ALtAwdCERuQO4w73bISIHJlP0BOQNryEWHY2T94m9z1gSD+8RRnmfr3xhSuucP56FpnWHqqo+CDw4XdsTkc3j+fgS7ex9xpZ4eJ/x8B7B2/c5ns/tx4DiIffnuo+NuIyIBIBMnB2rxhhjPDCecH8TWCwipSKSANwMrB+2zHrgY+7tDwMbItFvN8YYMz5jtmXcHvpdwB9wDoV8WFX3iMi9wGZVXQ/8CPipiJQDTTh/AGaCaWsBeczeZ2yJh/cZD+8RPHyfnp2haowxJnLsWDljjIlBFu7GGBODYjLcRaRYRF4Qkb0iskdEPu11TZEiIn4R2SYiT3tdS6SISJaIPCki+0Vkn4hc5HVNkSAin3F/XneLyGMikuR1TeEgIg+LyEkR2T3ksRwReU5EDrn/ZntZYziM8j6/6f7c7hSRX4tI1nTVE5PhDvQDd6vqcuBC4O9EZLnHNUXKp4F9XhcRYd8Ffq+qZwHvIgbfr4jMAf4BWKOqK3EOXpgpByZM1SPANcMe+yLwvKouBp5370e7R3jn+3wOWKmq5+BM4/LP01VMTIa7qh5X1a3u7XacMIjsub4eEJG5wLXAQ17XEikikglcjnNEFqraq6ot3lYVMQEg2T1XJAWo9biesFDVl3GOohvqBuDH7u0fAx+a1qIiYKT3qap/VNV+9+4bOOcJTYuYDPehRKQEOBco87aSiPhP4PNALF86vhSoB/7HbT89JCIxd9UQVT0GfAuoAo4Drar6R2+riqhCVT3u3q4D4uGSS38FPDtdG4vpcBeRNOCXwD+qakxdF05ErgNOquoWr2uJsACwGviBqp4LdBIbH+Hfxu0534Dzx2w2kCoiH/W2qunhnvAY08dki8iXcNrFj07XNmM23EUkiBPsj6rqr7yuJwIuAa4XkaPA48B7RORn3pYUETVAjaoOfvJ6EifsY81VQIWq1qtqH/Ar4GKPa4qkEyIyC8D996TH9USMiNwOXAd8ZDrP3I/JcBdnbt4fAftU9dte1xMJqvrPqjpXVUtwdrxtUNWYG+mpah1QLSJL3Yfey9unm44VVcCFIpLi/vy+lxjccTzE0ClLPgb8xsNaIsa90NHngetVtWs6tx2T4Y4zqr0NZzS73f36M6+LMpP298CjIrITWAV83eN6ws79ZPIksBXYhfO7GROn6IvIY8DrwFIRqRGRTwD3AVeLyCGcTy33nWkd0WCU9/l9IB14zs2hH05bPTb9gDHGxJ5YHbkbY0xcs3A3xpgYZOFujDExyMLdGGNikIW7McbEIAt3Y4yJQRbuJiq5c8yccaZPEXlERD48wuMlInJr5KqbHBF5UUTWeF2HiQ0W7iYqqepfq+pkz1QtAWZcuBsTThbuxlMi8k8i8g/u7e+IyAb39ntE5FEReZ+IvC4iW0XkCXcyuLeNckXkEyJyUEQ2ich/i8j3h2zichHZKCJHhozi7wMuc88Y/Mwodd0uIr9xt3NIRP51yHOfdS+osVtE/tF9rGTYRRo+JyJfHVLrf7j1HRSRy9zHk0XkcfcCJL8Gkt3H/e6njt0ismu0Go05k4DXBZi49wpwN/A9YA2Q6E76dhmwE/gycJWqdorIF4DPAvcOvlhEZgNfwZlMrB3YAOwYsv5ZwKXAWTjzmTyJM6vk51T1ujFqWwusBLqAN0XkdzizF34cuAAQoExEXgKax1hXQFXXutNg/CvOKfd/C3Sp6jIROQdn6gFwpliY4160g+m8eo+JHTZyN17bApwnIhnAKZy5OdbghHs3sBx4TUS240wwNX/Y69cCL6lqkzub4hPDnn9KVUNuC2eic4Y/p6qNqtqNM0vjpe7Xr1W1U1U73McvG8e6Bmcm3YLTFgLnIiQ/A1DVnTh/zACOAAtE5H53
4qmYmq7aTA8buRtPqWqfiFQAtwMbcQLuSmARUIETsLdMYROnhtyWiZY3xv2h+nn7YGn49U8H6xhgjN87VW0WkXcB7wfuBP4C50IPxoybjdzNTPAK8DngZff2ncA2nMuSXSIiiwBEJFVElgx77ZvAFSKS7V6e7qZxbK8dZ6a+sVwtzoWck3EuA/eaW9+H3Kl5U4Eb3cdOAAUikisiiTjzd4/lZdwduyKyEjjHvZ0H+FT1lzhtqVicv95EmI3czUzwCvAl4HW3t94DvKKq9e6FDh5zAxOcsDs4+EJVPSYiXwc24Vy/cj/QOsb2dgIDIrIDeERVvzPKcptwLvgyF/iZqm4G5xBL9zmAh1R1m/v4ve7jx9w6xvIDnMsH7sOZu33wqlpz3McHB1/TdlFlEztsyl8T9UQkTVU73JH7r4GHVfXXU1zn7cAaVb0rHDUaM92sLWNiwVfdHa67cfr0T3lcjzGes5G7iWsi8n7gP4Y9XKGqN3pRjzHhYuFujDExyNoyxhgTgyzcjTEmBlm4G2NMDLJwN8aYGPT/A6Wntqp7pXV6AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig = sns.distplot(df[[\"weight_pounds\"]])\n",
    "fig.set_title(\"Distribution of baby weight\")\n",
    "fig.set_xlabel(\"weight_pounds\")\n",
    "fig.figure.savefig(\"weight_distrib.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7.497811242931211"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#average weight_pounds for this cross section\n",
    "np.mean(df.weight_pounds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9896963447035907"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.std(df.weight_pounds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "weeks=36 age=28 mean=6.734255476277215 stddev=1.1628149516815478\n"
     ]
    }
   ],
   "source": [
    "weeks = 36\n",
    "age = 28\n",
    "query = \"\"\"\n",
    "SELECT\n",
    "  weight_pounds,\n",
    "  is_male,\n",
    "  gestation_weeks,\n",
    "  mother_age,\n",
    "  plurality,\n",
    "  mother_race\n",
    "FROM\n",
    "  `bigquery-public-data.samples.natality`\n",
    "WHERE\n",
    "  weight_pounds IS NOT NULL\n",
    "  AND is_male = true\n",
    "  AND gestation_weeks = {}\n",
    "  AND mother_age = {}\n",
    "  AND mother_race = 1\n",
    "  AND plurality = 1\n",
    "  AND RAND() < 0.01\n",
    "\"\"\".format(weeks, age)\n",
    "df = bq.query(query).to_dataframe()\n",
    "print('weeks={} age={} mean={} stddev={}'.format(weeks, age, np.mean(df.weight_pounds), np.std(df.weight_pounds)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Comparing categorical label and regression\n",
    "\n",
    "Since baby weight is a positive real value, this is intuitively a regression problem. However, we can train the model as a multi-class classification by bucketizing the output label. At inference time, the model then predicts a collection of probabilities corresponding to these potential outputs. \n",
    "\n",
    "Let's do both and see how they compare. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import tensorflow as tf\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from tensorflow.keras.utils import to_categorical\n",
    "from tensorflow import keras\n",
    "from tensorflow import feature_column as fc\n",
    "from tensorflow.keras import layers, models, Model\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv(\"./data/babyweight_train.csv\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll use the same features for both models. But we need to create a categorical weight label for the classification model. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# prepare inputs\n",
    "df.is_male = df.is_male.astype(str)\n",
    "\n",
    "df.mother_race.fillna(0, inplace = True)\n",
    "df.mother_race = df.mother_race.astype(str)\n",
    "\n",
    "# create categorical label\n",
    "def categorical_weight(weight_pounds):\n",
    "    if weight_pounds < 3.31:\n",
    "        return 0\n",
    "    elif weight_pounds >= 3.31 and weight_pounds < 5.5:\n",
    "        return 1\n",
    "    elif weight_pounds >= 5.5 and weight_pounds < 8.8:\n",
    "        return 2\n",
    "    else:\n",
    "        return 3\n",
    "\n",
    "df[\"weight_category\"] = df.weight_pounds.apply(lambda x: categorical_weight(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>weight_pounds</th>\n",
       "      <th>is_male</th>\n",
       "      <th>mother_age</th>\n",
       "      <th>plurality</th>\n",
       "      <th>gestation_weeks</th>\n",
       "      <th>mother_race</th>\n",
       "      <th>weight_category</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>7.749249</td>\n",
       "      <td>False</td>\n",
       "      <td>12</td>\n",
       "      <td>Single(1)</td>\n",
       "      <td>40</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>7.561856</td>\n",
       "      <td>True</td>\n",
       "      <td>12</td>\n",
       "      <td>Single(1)</td>\n",
       "      <td>40</td>\n",
       "      <td>2.0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>7.187070</td>\n",
       "      <td>False</td>\n",
       "      <td>12</td>\n",
       "      <td>Single(1)</td>\n",
       "      <td>34</td>\n",
       "      <td>3.0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>6.375769</td>\n",
       "      <td>True</td>\n",
       "      <td>12</td>\n",
       "      <td>Single(1)</td>\n",
       "      <td>36</td>\n",
       "      <td>2.0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>7.936641</td>\n",
       "      <td>False</td>\n",
       "      <td>12</td>\n",
       "      <td>Single(1)</td>\n",
       "      <td>35</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   weight_pounds is_male  mother_age  plurality  gestation_weeks mother_race  \\\n",
       "0       7.749249   False          12  Single(1)               40         1.0   \n",
       "1       7.561856    True          12  Single(1)               40         2.0   \n",
       "2       7.187070   False          12  Single(1)               34         3.0   \n",
       "3       6.375769    True          12  Single(1)               36         2.0   \n",
       "4       7.936641   False          12  Single(1)               35         0.0   \n",
       "\n",
       "   weight_category  \n",
       "0                2  \n",
       "1                2  \n",
       "2                2  \n",
       "3                2  \n",
       "4                2  "
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode_labels(classes):\n",
    "    one_hots = to_categorical(classes)\n",
    "    return one_hots\n",
    "\n",
    "FEATURES = ['is_male', 'mother_age', 'plurality', 'gestation_weeks', 'mother_race']\n",
    "\n",
    "LABEL_CLS = ['weight_category']\n",
    "LABEL_REG = ['weight_pounds']\n",
    "\n",
    "N_TRAIN = int(df.shape[0] * 0.80)\n",
    "\n",
    "X_train = df[FEATURES][:N_TRAIN]\n",
    "X_valid = df[FEATURES][N_TRAIN:]\n",
    "\n",
    "y_train_cls = encode_labels(df[LABEL_CLS][:N_TRAIN])\n",
    "y_train_reg = df[LABEL_REG][:N_TRAIN]\n",
    "\n",
    "y_valid_cls = encode_labels(df[LABEL_CLS][N_TRAIN:])\n",
    "y_valid_reg = df[LABEL_REG][N_TRAIN:]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create `tf.data` datsets for both classification and regression."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train/validation dataset for classification model\n",
    "cls_train_data = tf.data.Dataset.from_tensor_slices((X_train.to_dict('list'), y_train_cls))\n",
    "cls_valid_data = tf.data.Dataset.from_tensor_slices((X_valid.to_dict('list'), y_valid_cls))\n",
    "\n",
    "# train/validation dataset for regression model\n",
    "reg_train_data = tf.data.Dataset.from_tensor_slices((X_train.to_dict('list'), y_train_reg.values))\n",
    "reg_valid_data = tf.data.Dataset.from_tensor_slices((X_valid.to_dict('list'), y_valid_reg.values))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "({'is_male': <tf.Tensor: shape=(), dtype=string, numpy=b'False'>, 'mother_age': <tf.Tensor: shape=(), dtype=int32, numpy=12>, 'plurality': <tf.Tensor: shape=(), dtype=string, numpy=b'Single(1)'>, 'gestation_weeks': <tf.Tensor: shape=(), dtype=int32, numpy=40>, 'mother_race': <tf.Tensor: shape=(), dtype=string, numpy=b'1.0'>}, <tf.Tensor: shape=(4,), dtype=float32, numpy=array([0., 0., 1., 0.], dtype=float32)>)\n",
      "\n",
      "({'is_male': <tf.Tensor: shape=(), dtype=string, numpy=b'False'>, 'mother_age': <tf.Tensor: shape=(), dtype=int32, numpy=12>, 'plurality': <tf.Tensor: shape=(), dtype=string, numpy=b'Single(1)'>, 'gestation_weeks': <tf.Tensor: shape=(), dtype=int32, numpy=40>, 'mother_race': <tf.Tensor: shape=(), dtype=string, numpy=b'1.0'>}, <tf.Tensor: shape=(1,), dtype=float64, numpy=array([7.74924851])>)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Examine the two datasets. Notice the different label values.\n",
    "for data_type in [cls_train_data, reg_train_data]:\n",
    "    for dict_slice in data_type.take(1):\n",
    "        print(\"{}\\n\".format(dict_slice))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create feature columns to handle categorical variables\n",
    "numeric_columns = [fc.numeric_column(\"mother_age\"),\n",
    "                  fc.numeric_column(\"gestation_weeks\")]\n",
    "\n",
    "CATEGORIES = {\n",
    "    'plurality': list(df.plurality.unique()),\n",
    "    'is_male' : list(df.is_male.unique()),\n",
    "    'mother_race': list(df.mother_race.unique())\n",
    "}\n",
    "\n",
    "categorical_columns = []\n",
    "for feature, vocab in CATEGORIES.items():\n",
    "    cat_col = fc.categorical_column_with_vocabulary_list(\n",
    "        key=feature, vocabulary_list=vocab, dtype=tf.string)\n",
    "    categorical_columns.append(fc.indicator_column(cat_col))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create Inputs for model\n",
    "inputs = {colname: tf.keras.layers.Input(\n",
    "    name=colname, shape=(), dtype=\"float32\")\n",
    "    for colname in [\"mother_age\", \"gestation_weeks\"]}\n",
    "inputs.update({colname: tf.keras.layers.Input(\n",
    "    name=colname, shape=(), dtype=tf.string)\n",
    "    for colname in [\"plurality\", \"is_male\", \"mother_race\"]})\n",
    "\n",
    "# build DenseFeatures for the model (the same feature tensor feeds both models)\n",
    "dnn_inputs = layers.DenseFeatures(categorical_columns+numeric_columns)(inputs)\n",
    "\n",
    "# Each model gets its own hidden layers. If both heads were stacked on the same\n",
    "# Dense layers, fitting the regression model would also update the hidden weights\n",
    "# the classification model predicts with, so the two could not be compared fairly.\n",
    "\n",
    "# create classification model: softmax over the 4 weight buckets\n",
    "h1 = layers.Dense(20, activation=\"relu\")(dnn_inputs)\n",
    "h2 = layers.Dense(10, activation=\"relu\")(h1)\n",
    "cls_output = layers.Dense(4, activation=\"softmax\")(h2)\n",
    "cls_model = tf.keras.models.Model(inputs=inputs, outputs=cls_output)\n",
    "cls_model.compile(optimizer='adam',\n",
    "                  loss=tf.keras.losses.CategoricalCrossentropy(),\n",
    "                  metrics=['accuracy'])\n",
    "\n",
    "# create regression model: a single non-negative output (weight in pounds)\n",
    "reg_h1 = layers.Dense(20, activation=\"relu\")(dnn_inputs)\n",
    "reg_h2 = layers.Dense(10, activation=\"relu\")(reg_h1)\n",
    "reg_output = layers.Dense(1, activation=\"relu\")(reg_h2)\n",
    "reg_model = tf.keras.models.Model(inputs=inputs, outputs=reg_output)\n",
    "reg_model.compile(optimizer='adam',\n",
    "                  loss=tf.keras.losses.MeanSquaredError(),\n",
    "                  metrics=['mse'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, train the classification model and examine the validation accuracy. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train for 4234 steps\n",
      "4234/4234 [==============================] - 21s 5ms/step - loss: 0.4958 - accuracy: 0.8475\n",
      "1/1 [==============================] - 1s 609ms/step - loss: 0.9457 - accuracy: 0.6750\n",
      "Validation accuracy for classifcation model: 0.6749759316444397\n"
     ]
    }
   ],
   "source": [
    "# train the classification model\n",
    "cls_model.fit(cls_train_data.batch(50), epochs=1)\n",
    "\n",
    "val_loss, val_accuracy = cls_model.evaluate(cls_valid_data.batch(X_valid.shape[0]))\n",
    "print(\"Validation accuracy for classifcation model: {}\".format(val_accuracy))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll train the regression model and examine the validation RMSE."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train for 4234 steps\n",
      "4234/4234 [==============================] - 33s 8ms/step - loss: 1.0646 - mse: 1.0647\n",
      "1/1 [==============================] - 1s 556ms/step - loss: 1.9008 - mse: 1.9008\n",
      "Validation RMSE for regression model: 1.378703721169823\n"
     ]
    }
   ],
   "source": [
    "# train the regression model\n",
    "reg_model.fit(reg_train_data.batch(50), epochs=1)\n",
    "\n",
    "val_loss, val_mse = reg_model.evaluate(reg_valid_data.batch(X_valid.shape[0]))\n",
    "print(\"Validation RMSE for regression model: {}\".format(val_mse**0.5))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The regression model gives a single numeric prediction of baby weight. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(7.286859, dtype=float32)"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "preds = reg_model.predict(x={\"gestation_weeks\": tf.convert_to_tensor([38]),\n",
    "                             \"is_male\": tf.convert_to_tensor([\"True\"]),\n",
    "                             \"mother_age\": tf.convert_to_tensor([28]),\n",
    "                             \"mother_race\": tf.convert_to_tensor([\"1.0\"]),\n",
    "                             \"plurality\": tf.convert_to_tensor([\"Single(1)\"])},\n",
    "                          steps=1).squeeze()\n",
    "preds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The classification model predicts a probability for each bucket of values. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([7.7168038e-04, 5.1103556e-03, 9.3985993e-01, 5.4258034e-02],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "preds = cls_model.predict(x={\"gestation_weeks\": tf.convert_to_tensor([38]),\n",
    "                             \"is_male\": tf.convert_to_tensor([\"True\"]),\n",
    "                             \"mother_age\": tf.convert_to_tensor([28]),\n",
    "                             \"mother_race\": tf.convert_to_tensor([\"1.0\"]),\n",
    "                             \"plurality\": tf.convert_to_tensor([\"Single(1)\"])},\n",
    "                          steps=1).squeeze()\n",
    "preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAEJCAYAAACE39xMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAVwUlEQVR4nO3de7xdZX3n8c/XRKo2FFCCSkCDLaBhKgyE26uKWhVBq6iDBbyCIKIi6FgrOtZhCp0ZxRmcVpgUkKLjJbQKioqgY4ebiFwU5CY2EwQiiomCElEx+Osfax3YnO6cs5Ock5M8fN6v135lrfU8a+3fXuec7372s/beSVUhSdr4PWqmC5AkTQ0DXZIaYaBLUiMMdElqhIEuSY0w0CWpEQa61kmS+UkqyeyZrgUgybOT3DJi3+cmWTbdNa2LJBclOaJffk2Sr67lcb6S5A1TW502NAa6SPKDJL9KsjLJ3Um+nGTbma5rbVTVpVW141QcK8lZSU6cimNNhar6VFXtO1m/JMcn+eS4ffevqo9PX3XaEBjoGvPSqpoDPBm4C/i7Ga6nORvKqxi1y0DXw1TVr4HPAgvGtiV5SZLvJPlFkjuSHD9k1zcmuTPJj5K8q9/vSUnuS/KEgWPtlmR5kkcP7pzkMf2rhC379fcnWZXkD/r1E5N8pF/+vSQfTnJ7kruSLEry2L7tYdMoSXbta783yT8lOXv8qDvJu5L8pK/9sH7bkcBrgL/sX7l8cdj56qebjkmyNMmKJCcleVTfdmiSbyQ5OcnPgOP77W9McnP/aujCJE8dON4Lk3wvyc+TfBTIQNuhSS4bWN8pydeS/Kw/D+9Lsh/wPuCgvu7r+r6DUzeP6s/vbf3j/kSSzfq2sSm0N/Tnd0WS/zTssWvDY6DrYZI8DjgIuGJg8y+B1wObAy8B3pLk5eN2fR6wPbAvcFySF1TVj4GLgD8f6PdaYHFV/XZw5/6J5CrgOf2mfYDbgD8ZWL+4X/4gsAOwC/BHwDzgA0MeyybAucBZwOOBzwCvGNftScBm/TEOB05JskVVnQZ8CvhQVc2pqpeOP/6AVwALgV2BA4A3DrTtCSwFtgL+pj9v7wNeCcwFLu3ron8y+xzwfmBL4P8PPP7xj21T4P8CFwBb9+fh61V1AfBfgbP7uncesvuh/e15wNOAOcBHx/V5FrAj8HzgA0meMcHj14aiqrw9wm/AD4CVwD3AKuBO4I8n6P8R4OR+eT5QwNMH2j8EfKxfPgj4Rr88C/gxsMdqjnsC8LfA7L7fscB/Bx4D/Iou5EL3BPOHA/vtDdzaLz8XWNYv7wP8EMhA38uAEwf6/gqYPdD+E2Cvfvmssb4TnIsC9htYfytdsEIXmreP6/8V4PCB9UcB9wFPpXvSvGKgLcAy4IiB413WLx8CfGc1NR0PfHLctosGjvN14K0DbTsCv+3P+9jPc5uB9iuBg2f699Tb5DdH6Brz8qraHPg94Gjg4iRPAkiyZ5L/10+V/Bw4ii5cB90xsHwb3agR4AvAgiRPA14I/LyqrlxNDRfTheyuwPXA1+hG7HsBS6pqBd2o9nHANUnuSXIP3Sh17pDjbQ38sPpUGlInwE+ratXA+n10I9Y1sbrHPuz+ngr8r4Haf0YX3PP6/R7s39c9fv8x29KN4NfG1n2dgzXPBp44sO3HA8trc040Awx0PUxVPVBV5wAP0L3sBvg0cB6wbVVtBixiYG63N/iumKfQjfKpbirlH+nmo18H/J8J7v5yutHiK4CLq+qm/lgv4aHplhV0o+qdqmrz/rZZdRd0x/sRMC/JYK1r8u6dUb+KdOhjX80x7gDePFD75lX12Kq6vK/3wWP1da+u3juAP1zLuu+ke2IZrHkV3cVwbcQMdD1MOgcAWwA395s3BX5WVb9Osgfw6iG7/lWSxyXZCTgMOHug7RN00wUvAz45ZF8Aquo+4BrgbTwU4JcDbx5br6rf
AacDJyfZqq95XpIXDTnkN+memI5OMrt/XHtMcgoG3UU3xzyZdyfZIt1bPY/l4Y99vEXAe/vzRJLNkryqb/sysFOSV6Z7R8wxdHP8w3wJeFKSd/QXiTdNsudA3fPHLs4O8RngnUm2SzKHh+bcV62mvzYSBrrGfDHJSuAXwN8Ab6iqG/u2twJ/neReuouP/zhk/4uBJXTzsx+uqgc/AFNV3wB+B3y7qn4wSR0XA4+mm7cdW98UuGSgz3v6+7oiyS/oLg7+m/eeV9X9dBcfD6e7PvBauiD8zSQ1jPkY3XTRPUk+P0G/L9A9EV1LF8ofW13HqjqX7qLu4r72G4D9+7YVwKvorhv8lO4i8zdWc5x76aawXko3PfIvdBc5Af6p//enSb49ZPcz6V4pXQLcCvwaePsEj08biTx8elGaHkn+Gfh0VZ0xw3V8C1hUVf8wRccrYPuqWjIVx5PWhSN0Tbsku9Nd6JxoKmK67vs56d4PPzvdR9+fSXcRVWqOn1zTtEryceDlwLH9NMH6tiPdFNEcuneFHFhVP5qBOqRp55SLJDXCKRdJasSMTblsueWWNX/+/Jm6e0naKF1zzTUrqmrYB+lmLtDnz5/P1VdfPVN3L0kbpSS3ra7NKRdJaoSBLkmNMNAlqREGuiQ1wkCXpEYY6JLUCANdkhphoEtSIwx0SWqE37YorYWTv/b9mS5hRr3zhTvMdAkawhG6JDXCQJekRhjoktQIA12SGmGgS1IjDHRJaoSBLkmNMNAlqREGuiQ1wkCXpEYY6JLUCANdkhphoEtSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1AgDXZIaYaBLUiMMdElqhIEuSY0w0CWpEQa6JDXCQJekRowU6En2S3JLkiVJjhvSvlmSLya5LsmNSQ6b+lIlSROZNNCTzAJOAfYHFgCHJFkwrtvbgJuqamfgucD/SLLJFNcqSZrAKCP0PYAlVbW0qu4HFgMHjOtTwKZJAswBfgasmtJKJUkTGiXQ5wF3DKwv67cN+ijwDOBO4Hrg2Kr63ZRUKEkaySiBniHbatz6i4Brga2BXYCPJvmDf3Og5MgkVye5evny5WtcrCRp9UYJ9GXAtgPr29CNxAcdBpxTnSXArcDTxx+oqk6rqoVVtXDu3LlrW7MkaYhRAv0qYPsk2/UXOg8GzhvX53bg+QBJngjsCCydykIlSRObPVmHqlqV5GjgQmAWcGZV3ZjkqL59EXACcFaS6+mmaN5TVSumsW5J0jiTBjpAVZ0PnD9u26KB5TuBfae2NEnSmvCTopLUCANdkhphoEtSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1AgDXZIaYaBLUiMMdElqhIEuSY0w0CWpEQa6JDXCQJekRhjoktQIA12SGmGgS1IjDHRJaoSBLkmNMNAlqREGuiQ1wkCXpEYY6JLUCANdkhphoEtSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1AgDXZIaYaBLUiNGCvQk+yW5JcmSJMetps9zk1yb5MYkF09tmZKkycyerEOSWcApwAuBZcBVSc6rqpsG+mwOnArsV1W3J9lqugqWJA03ygh9D2BJVS2tqvuBxcAB4/q8Gjinqm4HqKqfTG2ZkqTJjBLo84A7BtaX9dsG7QBskeSiJNckef1UFShJGs2kUy5AhmyrIcfZDXg+8Fjgm0muqKrvP+xAyZHAkQBPecpT1rxaSdJqjTJCXwZsO7C+DXDnkD4XVNUvq2oFcAmw8/gDVdVpVbWwqhbOnTt3bWuWJA0xSqBfBWyfZLskmwAHA+eN6/MF4NlJZid5HLAncPPUlipJmsikUy5VtSrJ0cCFwCzgzKq6MclRffuiqro5yQXAd4HfAWdU1Q3TWbgk6eFGmUOnqs4Hzh+3bdG49ZOAk6auNEnSmvCTopLUCANdkhphoEtSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1AgDXZIaYaBLUiMMdElq
hIEuSY0w0CWpEQa6JDXCQJekRhjoktQIA12SGmGgS1IjDHRJaoSBLkmNMNAlqREGuiQ1wkCXpEYY6JLUCANdkhphoEtSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1AgDXZIaYaBLUiMMdElqxEiBnmS/JLckWZLkuAn67Z7kgSQHTl2JkqRRTBroSWYBpwD7AwuAQ5IsWE2/DwIXTnWRkqTJjTJC3wNYUlVLq+p+YDFwwJB+bwc+B/xkCuuTJI1olECfB9wxsL6s3/agJPOAVwCLJjpQkiOTXJ3k6uXLl69prZKkCYwS6BmyrcatfwR4T1U9MNGBquq0qlpYVQvnzp07ao2SpBHMHqHPMmDbgfVtgDvH9VkILE4CsCXw4iSrqurzU1KlJGlSowT6VcD2SbYDfggcDLx6sENVbTe2nOQs4EuGuSStX5MGelWtSnI03btXZgFnVtWNSY7q2yecN5ckrR+jjNCpqvOB88dtGxrkVXXoupclSVpTflJUkhphoEtSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1AgDXZIaYaBLUiMMdElqhIEuSY0w0CWpEQa6JDXCQJekRhjoktQIA12SGmGgS1IjDHRJaoSBLkmNMNAlqREGuiQ1wkCXpEYY6JLUCANdkhphoEtSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1AgDXZIaYaBLUiMMdElqxEiBnmS/JLckWZLkuCHtr0ny3f52eZKdp75USdJEJg30JLOAU4D9gQXAIUkWjOt2K/CcqnomcAJw2lQXKkma2Cgj9D2AJVW1tKruBxYDBwx2qKrLq+rufvUKYJupLVOSNJlRAn0ecMfA+rJ+2+ocDnxlWEOSI5NcneTq5cuXj16lJGlSowR6hmyroR2T59EF+nuGtVfVaVW1sKoWzp07d/QqJUmTmj1Cn2XAtgPr2wB3ju+U5JnAGcD+VfXTqSlPkjSqUUboVwHbJ9kuySbAwcB5gx2SPAU4B3hdVX1/6suUJE1m0hF6Va1KcjRwITALOLOqbkxyVN++CPgA8ATg1CQAq6pq4fSVLUkab5QpF6rqfOD8cdsWDSwfARwxtaVJktaEnxSVpEYY6JLUCANdkhphoEtSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1AgDXZIaYaBLUiMMdElqhIEuSY0w0CWpEQa6JDXCQJekRhjoktQIA12SGmGgS1IjDHRJaoSBLkmNMNAlqREGuiQ1wkCXpEYY6JLUiNkzXYCkR56Tv/b9mS5hRr3zhTtMy3EdoUtSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1AgDXZIaMVKgJ9kvyS1JliQ5bkh7kvxt3/7dJLtOfamSpIlMGuhJZgGnAPsDC4BDkiwY121/YPv+diTwv6e4TknSJEb56P8ewJKqWgqQZDFwAHDTQJ8DgE9UVQFXJNk8yZOr6kdTXrGmhB+9np6PXkszaZRAnwfcMbC+DNhzhD7zgIcFepIj6UbwACuT3LJG1W44tgRWzHQRG7kZPYf/cabueOp4/tbNxnz+nrq6hlECPUO21Vr0oapOA04b4T43aEmurqqFM13HxsxzuG48f+um1fM3ykXRZcC2A+vbAHeuRR9J0jQaJdCvArZPsl2STYCDgfPG9TkPeH3/bpe9gJ87fy5J69ekUy5VtSrJ0cCFwCzgzKq6MclRffsi4HzgxcAS4D7gsOkreYOw0U8bbQA8h+vG87dumjx/6d6YIkna2PlJUUlqhIEuSY0w0DUtkqyc6Rr0yJNkfpIbhmz/6yQvmGTf45P8xfRVN/38T6KnUJKVVTVnpuvQI0OSWVX1wEzXsTGoqg/MdA3rgyP01ejfgun5WUf9eTwpyQ1Jrk9yUL/91CQv65fPTXJmv3x4khNnsubpkuTzSa5JcmOSI5O8JcmHBtoPTfJ3/fJrk1yZ5Nokf99/pxJJVvajzW8Beyf5QJKr+vN7WpL0
/Xbvvyjvm2Pnv98+q1+/qm9/8wyciuk2K8np/Xn+apLHJjkryYEASV6c5HtJLuu/VPBLA/suSHJRkqVJjpmh+tda84GV5INJ3jqwfnySdyV598Av9X/p2+YnuTnJqcC3gb9KcvLAvm9K8j9HuE9D7CGvBHYBdgZeAJyU5MnAJcCz+z7z6L74DeBZwKXru8j15I1VtRuwEDgGOIfu/Iw5CDg7yTP65T+pql2AB4DX9H1+H7ihqvasqsuAj1bV7lX174DHAn/W9/sH4Kiq2rvff8zhdJ8T2R3YHXhTku2m48HOoO2BU6pqJ+Ae4D+MNSR5DPD3wP5V9Sxg7rh9nw68iO47rP5zkkevn5KnRvOBDiym++MY8+fAcrof+h50YbNbkn369h3pvmjs3wMfBl428EM9jO4PZTKG2EOeBXymqh6oqruAi+mC5FLg2em+ufMm4K7+HO0NXD5j1U6vY5JcB1xB98nq7YClSfZK8gS6371vAM8HdgOuSnJtv/60/hgPAJ8bOObzknwryfXAnwI7Jdkc2LSqxs7jpwf670v3IcBrgW8BT6D7W2jJrVV1bb98DTB/oO3pwNKqurVf/8y4fb9cVb+pqhXAT4AnTmulU6z5OfSq+k6SrZJsTfdsfDfwTLpf7O/03ebQ/VLfDtxWVVf0+/4yyT8Df5bkZuDRVXX9CHf7YIjRBdVgiL1jIMS2GAixje7l3YiGfc8PVfXDJFsA+9E90T2e7sl2ZVXdux7rWy+SPJfuyX3vqrovyUXAY4Cz6R7394Bzq6r6aZOPV9V7hxzq12Pz5v1o81RgYVXdkeT4/phDz/lYKcDbq+rCqXlkG6TfDCw/QPfKZcxE52bYvhtVRj4SRugAnwUOpBupL6b7of63qtqlv/1RVX2s7/vLcfueARzK6KNzmCDEgMEQu5SGQ6x3CXBQP3c7F9gHuLJv+ybwDh46F39Bu69UNgPu7sP86cBe/fZzgJcDh9CFO8DXgQOTbAWQ5PFJhn3D3mP6f1ckmUP3O05V3Q3cm+5rOKD7uo4xFwJvGXvVmWSHJL8/JY9w4/A94GlJ5vfrB62+68Zno3r2WQeLgdPpvjLzOcAfAyck+VRVrUwyD/jtsB2r6ltJtgV2pRvZj+IS4M1JPk438twHeHffNhZif0r3cvez/a1V59K9ArmO7hs4/7Kqfty3XQrsW1VLktxGd65aDfQLgKOSfBe4hW7ahaq6O8lNwIKqurLfdlOS9wNfTXdh/rfA24DbBg9YVfckOR24HvgB3fcujTkcOD3JL4GLgJ/328+gm4L4dv9KYDndE8ojQlX9qr+mdkGSFTw0uGjCI+aj//0c44qqel6/fixwRN+8Engt3UusL/UXmAb3PQ7YpaoGRzrD7mNlVc3p/1A+RPc/ORVwYlWd3fc5HDihqrbuR0n3AK+rqnOm6rFKSeZU1cp++TjgyVV17AyXtUEYOzf93+kpwL9U1cmT7bcxeMQE+rro39Z0clV9faZrkUbRv7PqvXSvwm8DDq2q5TNb1YYhyTuBNwCb0F1He1NV3TezVU0NA30C/bsFrgSuq6pXzXQ9kjQRA30N9W8vGzZSf35V/XR91yNJYwx0SWrEI+Vti5LUPANdkhphoEtSIwx0SWrEvwIlDCdS7bOA1gAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "objects = ('very_low', 'low', 'average', 'high')\n",
    "y_pos = np.arange(len(objects))\n",
    "predictions = list(preds)\n",
    "\n",
    "plt.bar(y_pos, predictions, align='center', alpha=0.5)\n",
    "plt.xticks(y_pos, objects)\n",
    "plt.title('Baby weight prediction')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Increasing the number of categorical labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll generalize the code above to accommodate `N` label buckets, instead of just 4. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read in the data and preprocess\n",
    "df = pd.read_csv(\"./data/babyweight_train.csv\")\n",
    "\n",
    "# prepare inputs: cast the categorical features to strings\n",
    "df[\"is_male\"] = df[\"is_male\"].astype(str)\n",
    "\n",
    "# Assign the filled column back instead of calling fillna(inplace=True) on\n",
    "# `df.mother_race`: the in-place form mutates an intermediate Series and is not\n",
    "# guaranteed to write through to `df` (it does not under pandas Copy-on-Write).\n",
    "df[\"mother_race\"] = df[\"mother_race\"].fillna(0).astype(str)\n",
    "\n",
    "# create categorical label\n",
    "MIN = np.min(df.weight_pounds)\n",
    "MAX = np.max(df.weight_pounds)\n",
    "NBUCKETS = 50\n",
    "\n",
    "def categorical_weight(weight_pounds, weight_min, weight_max, nbuckets=10):\n",
    "    \"\"\"Map a weight (scalar or array-like) to an integer bucket index.\n",
    "\n",
    "    `nbuckets` evenly spaced edges are laid from `weight_min` to `weight_max`.\n",
    "    A weight in [edge[i], edge[i+1]) gets index i; a weight equal to\n",
    "    `weight_max` gets index nbuckets - 1; a weight below `weight_min` gets -1.\n",
    "\n",
    "    Args:\n",
    "        weight_pounds: weight or weights to bucket.\n",
    "        weight_min: value of the first bucket edge.\n",
    "        weight_max: value of the last bucket edge.\n",
    "        nbuckets: number of edges passed to `np.linspace`.\n",
    "\n",
    "    Returns:\n",
    "        Bucket index (or array of indices) as returned by `np.digitize`, minus 1.\n",
    "    \"\"\"\n",
    "    buckets = np.linspace(weight_min, weight_max, nbuckets)\n",
    "    return np.digitize(weight_pounds, buckets) - 1\n",
    "\n",
    "# np.digitize is vectorized, so bucket the whole column in one call\n",
    "# instead of invoking the function once per row through .apply\n",
    "df[\"weight_category\"] = categorical_weight(df.weight_pounds.to_numpy(), MIN, MAX, NBUCKETS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode_labels(classes, num_classes=None):\n",
    "    \"\"\"One-hot encode integer class ids.\n",
    "\n",
    "    Args:\n",
    "        classes: integer class ids, passed straight to `to_categorical`.\n",
    "        num_classes: width of the one-hot vectors. When None, `to_categorical`\n",
    "            infers it from the largest id present in `classes`.\n",
    "\n",
    "    Returns:\n",
    "        The one-hot matrix produced by `to_categorical`.\n",
    "    \"\"\"\n",
    "    return to_categorical(classes, num_classes=num_classes)\n",
    "\n",
    "FEATURES = ['is_male', 'mother_age', 'plurality', 'gestation_weeks', 'mother_race']\n",
    "LABEL_COLUMN = ['weight_category']\n",
    "\n",
    "N_TRAIN = int(df.shape[0] * 0.80)\n",
    "\n",
    "# Pin the one-hot width to NBUCKETS: left to infer it, each slice would size its\n",
    "# labels by the largest bucket id it happens to contain, so train and validation\n",
    "# labels could differ in width from each other and from the model's output layer.\n",
    "X_train = df[FEATURES][:N_TRAIN]\n",
    "y_train = encode_labels(df[LABEL_COLUMN][:N_TRAIN], num_classes=NBUCKETS)\n",
    "X_valid = df[FEATURES][N_TRAIN:]\n",
    "y_valid = encode_labels(df[LABEL_COLUMN][N_TRAIN:], num_classes=NBUCKETS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create the training dataset\n",
    "train_data = tf.data.Dataset.from_tensor_slices((X_train.to_dict('list'), y_train))\n",
    "valid_data = tf.data.Dataset.from_tensor_slices((X_valid.to_dict('list'), y_valid))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create the feature columns and build the model. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create feature columns to handle categorical variables\n",
    "numeric_columns = [fc.numeric_column(\"mother_age\"),\n",
    "                  fc.numeric_column(\"gestation_weeks\")]\n",
    "\n",
    "CATEGORIES = {\n",
    "    'plurality': list(df.plurality.unique()),\n",
    "    'is_male' : list(df.is_male.unique()),\n",
    "    'mother_race': list(df.mother_race.unique())\n",
    "}\n",
    "\n",
    "categorical_columns = []\n",
    "for feature, vocab in CATEGORIES.items():\n",
    "    cat_col = fc.categorical_column_with_vocabulary_list(\n",
    "        key=feature, vocabulary_list=vocab, dtype=tf.string)\n",
    "    categorical_columns.append(fc.indicator_column(cat_col))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create Inputs for model\n",
    "inputs = {colname: tf.keras.layers.Input(\n",
    "    name=colname, shape=(), dtype=\"float32\")\n",
    "    for colname in [\"mother_age\", \"gestation_weeks\"]}\n",
    "inputs.update({colname: tf.keras.layers.Input(\n",
    "    name=colname, shape=(), dtype=tf.string)\n",
    "    for colname in [\"plurality\", \"is_male\", \"mother_race\"]})\n",
    "\n",
    "# build DenseFeatures for the model\n",
    "dnn_inputs = layers.DenseFeatures(categorical_columns+numeric_columns)(inputs)\n",
    "\n",
    "# model\n",
    "h1 = layers.Dense(20, activation=\"relu\")(dnn_inputs)\n",
    "h2 = layers.Dense(10, activation=\"relu\")(h1)\n",
    "output = layers.Dense(NBUCKETS, activation=\"softmax\")(h2)\n",
    "model = tf.keras.models.Model(inputs=inputs, outputs=output)\n",
    "\n",
    "model.compile(optimizer='adam',\n",
    "              loss=tf.keras.losses.CategoricalCrossentropy(),\n",
    "              metrics=['accuracy'])  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train for 4234 steps\n",
      "4234/4234 [==============================] - 20s 5ms/step - loss: 2.5945 - accuracy: 0.1329\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f6c32bc1e90>"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# train the model\n",
    "model.fit(train_data.batch(50), epochs=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Make a prediction on the example above. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = model.predict(x={\"gestation_weeks\": tf.convert_to_tensor([38]),\n",
    "                         \"is_male\": tf.convert_to_tensor([\"True\"]),\n",
    "                         \"mother_age\": tf.convert_to_tensor([28]),\n",
    "                         \"mother_race\": tf.convert_to_tensor([\"1.0\"]),\n",
    "                         \"plurality\": tf.convert_to_tensor([\"Single(1)\"])},\n",
    "                      steps=1).squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAEICAYAAABWJCMKAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3df7xVVZ3/8ddblMzfqNckZMT8kkUzxRChM5Xpw9EBM9GZnPRbZmpftJHyR84MY1PZVDP+Khonv/LQZNLJMqtxomJE8pumqcWPSEFC0VCuIFxFRQVB5PP9Y63jXWwO957LvfLD/X4+Hudxz157rbXXOvvHZ++19zlXEYGZmdXPDlu7AWZmtnU4AJiZ1ZQDgJlZTTkAmJnVlAOAmVlNOQCYmdWUA4BtcZKGSApJO27ttgBIer+kBS3mPUJS+2vdpt6QdIekT+b3H5V022bW8z+STuvb1tm2xAHANoukRZJWS3pB0jOSfiZp8NZu1+aIiLsi4pC+qEvStyV9pS/q6gsRcWNEHNNdPkkXS/pOpeyYiLj+tWudbW0OANYbH4qI3YCBwDLg37dye153tpWrJHt9cgCwXouIl4AfAsMaaZI+KOm3klZKWizp4iZFz5C0RNJSSZ/N5faXtErSPkVd75bUIWmnsrCknfNVyL55+p8krZO0R57+iqRv5PdvkHSFpMclLZM0SdIb87wNhnUkjchtf17SDyR9v3pWL+mzkpbntp+e08YBHwX+Pl8Z/aTZ55WHvz4j6VFJT0m6XNIOed4nJP1K0kRJK4CLc/oZkubnq61pkg4s6jta0u8lPSfpm4CKeZ+QdHcx/Q5J0yWtyJ/DRZJGAxcBH8nt/l3OWw4l7ZA/38dyv2+QtGee1xjSOy1/vk9J+lyzvtu2xQHAek3SLsBHgPuK5BeBjwN7AR8EPiXphErRI4GhwDHABEl/ERFPAncAf1Pk+xhwU0S8XBbOgWcG8IGcdDjwGPDeYvrO/P5S4K3AcOB/AYOALzTpS3/gFuDbwN7A94ATK9n2B/bMdZwJXCVpQERcA9wIXBYRu0XEh6r1F04ERgIjgLHAGcW8Q4FHgf2Ar+bP7SLgr4A24K7cLnLw+xHwT8C+wCNF/6t92x34OXAr8Ob8OdweEbcC/wJ8P7f7XU2KfyK/jgTeAuwGfLOS533AIcBRwBckvb2L/tu2ICL88qvHL2AR8ALwLLAOWAL8SRf5vwFMzO+HAAG8rZh/GXBdfv8R4Ff5fT/gSWDUJur9MnAlsGPOdy5wCbAzsJp0UBQpIB1clPsz4A/5/RFAe35/OPAEoCLv3cBXiryrgR2L+cuBw/L7bzfydvFZBDC6mP5b0oEY0kH28Ur+/wHOLKZ3AFYBB5KC7H3FPAHtwCeL+u7O708BfruJNl0MfKeSdkdRz+3A3xbzDgFezp97Y30eUMz/DXDy1t5O/er65SsA640TImIv4A3AeOBOSfsDSDpU0i/y0M1zwNmkg3FpcfH+MdJZKcCPgWGS3gIcDTwXEb/ZRBvuJB2URwAPANNJVwSHAQsj4inSWfMuwCxJz0p6lnQW3NakvjcDT0Q+ijVpJ8DTEbGumF5FOiPuiU31vdnyDgT+rWj7CtKBflAu92r+3O5q+YbBpCuEzfHm3M6yzTsCbyrSnizeb85nYluYA4D1WkS8EhH/BbxCGgYA+C4wBRgcEXsCkyjGprPyqaE/Il1FEGlo52bSePqpwH92sfh7SGejJwJ3RsSDua4P0jn88xTprP0dEbFXfu0Z6QZ21VJgkKSyrT15uqnVn9dt2vdN1LEYOKto+14R8caIuCe399W6crs31d7FwMGb2e4lpEBUtnkd6ea/baccAKzXlIwFBgDzc/LuwIqIeEnSKOB/Nyn6eUm7SHoHcDrw/WLeDaThi+OB7zQpC0BErAJmAefQecC/BzirMR0R64FrgYmS9sttHiTpL5tUeS8pkI2X
tGPu16huPoLSMtIYeXf+TtIApUdnz2XDvldNAv4xf05I2lPSSXnez4B3SPorpSeGPkO6R9HMT4H9JZ2Xb4rvLunQot1DGjejm/gecL6kgyTtRuc9g3WbyG/bAQcA642fSHoBWAl8FTgtIubleX8L/LOk50k3W29uUv5OYCFpfPmKiHj1C0sR8StgPTA7IhZ10447gZ1I486N6d2BXxZ5/iEv6z5JK0k3Qzd69j8i1pJutp5Jur/xMdKBc003bWi4jjR89ayk/+4i349JgWsO6SB+3aYyRsQtpJvYN+W2zwXG5HlPASeR7ns8Tbqp/qtN1PM8aUjtQ6ThmodJN3UBfpD/Pi1pdpPik0lXYr8E/gC8BHy6i/7ZdkAbDnWabTsk/T/guxHxra3cjl8DkyLiP/qovgCGRsTCvqjPbHP5CsC2SZLeQ7qx29XQyGu17A8ofR9hR6WfQngn6aax2euKv2Vo2xxJ1wMnAOfmYYst7RDSkNVupKdmPhwRS7dCO8xeUx4CMjOrKQ8BmZnV1HY1BLTvvvvGkCFDtnYzzMy2K7NmzXoqIjb64uN2FQCGDBnCzJkzt3YzzMy2K5Iea5buISAzs5pyADAzqykHADOzmnIAMDOrKQcAM7OacgAwM6spBwAzs5pyADAzqykHADOzmtquvglsti2YOP2hjdLOP/qtW6ElZr3jKwAzs5pyADAzqykHADOzmvI9AKs9j+lbXfkKwMysphwAzMxqygHAzKymHADMzGrKAcDMrKYcAMzMaqqlACBptKQFkhZKmtBk/tsk3StpjaQLi/RDJM0pXislnZfnXSzpiWLesX3XLTMz60633wOQ1A+4CjgaaAdmSJoSEQ8W2VYAnwFOKMtGxAJgeFHPE8AtRZaJEXFFr3pgZmabpZUrgFHAwoh4NCLWAjcBY8sMEbE8ImYAL3dRz1HAIxHx2Ga31szM+kwrAWAQsLiYbs9pPXUy8L1K2nhJ90uaLGlAs0KSxkmaKWlmR0fHZizWzMyaaSUAqEla9GQhkvoDxwM/KJKvBg4mDREtBb7WrGxEXBMRIyNiZFtbW08Wa2ZmXWglALQDg4vpA4AlPVzOGGB2RCxrJETEsoh4JSLWA9eShprMzGwLaSUAzACGSjoon8mfDEzp4XJOoTL8I2lgMXkiMLeHdZqZWS90+xRQRKyTNB6YBvQDJkfEPEln5/mTJO0PzAT2ANbnRz2HRcRKSbuQniA6q1L1ZZKGk4aTFjWZb2Zmr6GWfg46IqYCUytpk4r3T5KGhpqVXQXs0yT91B611MzM+pS/CWxmVlMOAGZmNeUAYGZWUw4AZmY15QBgZlZTDgBmZjXlAGBmVlMOAGZmNeUAYGZWUw4AZmY15QBgZlZTDgBmZjXlAGBmVlMOAGZmNeUAYGZWUw4AZmY15QBgZlZTDgBmZjXlAGBmVlMtBQBJoyUtkLRQ0oQm898m6V5JayRdWJm3SNIDkuZImlmk7y1puqSH898Bve+OmZm1qtsAIKkfcBUwBhgGnCJpWCXbCuAzwBWbqObIiBgeESOLtAnA7RExFLg9T5uZ2RbSyhXAKGBhRDwaEWuBm4CxZYaIWB4RM4CXe7DsscD1+f31wAk9KGtmZr3USgAYBCwupttzWqsCuE3SLEnjivQ3RcRSgPx3v2aFJY2TNFPSzI6Ojh4s1szMutJKAFCTtOjBMt4bESNIQ0jnSDq8B2WJiGsiYmREjGxra+tJUTMz60IrAaAdGFxMHwAsaXUBEbEk/10O3EIaUgJYJmkgQP67vNU6zcys91oJADOAoZIOktQfOBmY0krlknaVtHvjPXAMMDfPngKclt+fBvy4Jw03M7Pe2bG7DBGxTtJ4YBrQD5gcEfMknZ3nT5K0PzAT2ANYL+k80hND+wK3SGos67sRcWuu+hLgZklnAo8DJ/Vt18zMrCvdBgCAiJgKTK2kTSreP0kaGqpaCbxrE3U+DRzVckvNzKxP+ZvAZmY15QBgZlZTDgBmZjXl
AGBmVlMOAGZmNeUAYGZWUw4AZmY15QBgZlZTDgBmZjXlAGBmVlMOAGZmNeUAYGZWUw4AZmY15QBgZlZTDgBmZjXV0v8DMHs9mDj9oQ2mzz/6rVupJWbbBl8BmJnVlAOAmVlNtRQAJI2WtEDSQkkTmsx/m6R7Ja2RdGGRPljSLyTNlzRP0rnFvIslPSFpTn4d2zddMjOzVnR7D0BSP+Aq4GigHZghaUpEPFhkWwF8BjihUnwd8NmImC1pd2CWpOlF2YkRcUWve2FmZj3WyhXAKGBhRDwaEWuBm4CxZYaIWB4RM4CXK+lLI2J2fv88MB8Y1CctNzOzXmklAAwCFhfT7WzGQVzSEOBPgV8XyeMl3S9psqQBmyg3TtJMSTM7Ojp6ulgzM9uEVgKAmqRFTxYiaTfgR8B5EbEyJ18NHAwMB5YCX2tWNiKuiYiRETGyra2tJ4s1M7MutBIA2oHBxfQBwJJWFyBpJ9LB/8aI+K9GekQsi4hXImI9cC1pqMnMzLaQVgLADGCopIMk9QdOBqa0UrkkAdcB8yPi65V5A4vJE4G5rTXZzMz6QrdPAUXEOknjgWlAP2ByRMyTdHaeP0nS/sBMYA9gvaTzgGHAO4FTgQckzclVXhQRU4HLJA0nDSctAs7q266ZmVlXWvopiHzAnlpJm1S8f5I0NFR1N83vIRARp7beTDMz62v+JrCZWU05AJiZ1ZQDgJlZTTkAmJnVlAOAmVlNOQCYmdWUA4CZWU05AJiZ1ZQDgJlZTTkAmJnVlAOAmVlNOQCYmdWUA4CZWU05AJiZ1ZQDgJlZTTkAmJnVlAOAmVlNOQCYmdWUA4CZWU21FAAkjZa0QNJCSROazH+bpHslrZF0YStlJe0tabqkh/PfAb3vjpmZtarbACCpH3AVMAYYBpwiaVgl2wrgM8AVPSg7Abg9IoYCt+dpMzPbQlq5AhgFLIyIRyNiLXATMLbMEBHLI2IG8HIPyo4Frs/vrwdO2Mw+mJnZZmglAAwCFhfT7TmtFV2VfVNELAXIf/drVoGkcZJmSprZ0dHR4mLNzKw7rQQANUmLFuvvTdmUOeKaiBgZESPb2tp6UtTMzLrQSgBoBwYX0wcAS1qsv6uyyyQNBMh/l7dYp5mZ9YFWAsAMYKikgyT1B04GprRYf1dlpwCn5fenAT9uvdlmZtZbO3aXISLWSRoPTAP6AZMjYp6ks/P8SZL2B2YCewDrJZ0HDIuIlc3K5qovAW6WdCbwOHBSX3fOzMw2rdsAABARU4GplbRJxfsnScM7LZXN6U8DR/WksWZm1nf8TWAzs5pyADAzq6mWhoDM6mji9Ic2Sjv/6LduhZaYvTZ8BWBmVlMOAGZmNeUAYGZWUw4AZmY15QBgZlZTDgBmZjXlAGBmVlMOAGZmNeUAYGZWUw4AZmY15QBgZlZT/i0gsz7i3w6y7Y2vAMzMasoBwMysphwAzMxqqqUAIGm0pAWSFkqa0GS+JF2Z598vaUROP0TSnOK1Mv+/YCRdLOmJYt6xfds1MzPrSrc3gSX1A64CjgbagRmSpkTEg0W2McDQ/DoUuBo4NCIWAMOLep4AbinKTYyIK/qiI2Zm1jOtXAGMAhZGxKMRsRa4CRhbyTMWuCGS+4C9JA2s5DkKeCQiHut1q83MrNdaCQCDgMXFdHtO62mek4HvVdLG5yGjyZIGNFu4pHGSZkqa2dHR0UJzzcysFa0EADVJi57kkdQfOB74QTH/auBg0hDRUuBrzRYeEddExMiIGNnW1tZCc83MrBWtBIB2YHAxfQCwpId5xgCzI2JZIyEilkXEKxGxHriWNNRkZmZbSCsBYAYwVNJB+Uz+ZGBKJc8U4OP5aaDDgOciYmkx/xQqwz+VewQnAnN73HozM9ts3T4FFBHrJI0HpgH9gMkRMU/S2Xn+JGAqcCywEFgFnN4oL2kX0hNEZ1WqvkzScNJQ0aIm883M7DXU0m8BRcRU0kG+TJtUvA/gnE2UXQXs0yT91B611MzM+pS/
CWxmVlMOAGZmNeWfg7bXFf8ks1nrfAVgZlZTDgBmZjXlAGBmVlMOAGZmNeUAYGZWUw4AZmY15QBgZlZTDgBmZjXlAGBmVlP+JrBtl/yNX7PecwAwe405WNm2ykNAZmY15QBgZlZTDgBmZjXlAGBmVlMOAGZmNdVSAJA0WtICSQslTWgyX5KuzPPvlzSimLdI0gOS5kiaWaTvLWm6pIfz3wF90yUzM2tFtwFAUj/gKmAMMAw4RdKwSrYxwND8GgdcXZl/ZEQMj4iRRdoE4PaIGArcnqfNzGwLaeUKYBSwMCIejYi1wE3A2EqescANkdwH7CVpYDf1jgWuz++vB07oQbvNzKyXWgkAg4DFxXR7Tms1TwC3SZolaVyR500RsRQg/92v2cIljZM0U9LMjo6OFpprZmataCUAqEla9CDPeyNiBGmY6BxJh/egfUTENRExMiJGtrW19aSomZl1oZUA0A4MLqYPAJa0miciGn+XA7eQhpQAljWGifLf5T1tvJmZbb5WAsAMYKikgyT1B04GplTyTAE+np8GOgx4LiKWStpV0u4AknYFjgHmFmVOy+9PA37cy76YmVkPdPtjcBGxTtJ4YBrQD5gcEfMknZ3nTwKmAscCC4FVwOm5+JuAWyQ1lvXdiLg1z7sEuFnSmcDjwEl91iszM+tWS78GGhFTSQf5Mm1S8T6Ac5qUexR41ybqfBo4qieNNTOzvuNvApuZ1ZQDgJlZTTkAmJnVlP8jmG3T/N+0zF47vgIwM6spBwAzs5pyADAzqykHADOzmnIAMDOrKQcAM7OacgAwM6spfw/AbCvxdxxsa/MVgJlZTTkAmJnVlAOAmVlNOQCYmdWUA4CZWU05AJiZ1VRLAUDSaEkLJC2UNKHJfEm6Ms+/X9KInD5Y0i8kzZc0T9K5RZmLJT0haU5+Hdt33TIzs+50+z0ASf2Aq4CjgXZghqQpEfFgkW0MMDS/DgWuzn/XAZ+NiNmSdgdmSZpelJ0YEVf0XXfMzKxVrVwBjAIWRsSjEbEWuAkYW8kzFrghkvuAvSQNjIilETEbICKeB+YDg/qw/WZmtplaCQCDgMXFdDsbH8S7zSNpCPCnwK+L5PF5yGiypAHNFi5pnKSZkmZ2dHS00FwzM2tFKwFATdKiJ3kk7Qb8CDgvIlbm5KuBg4HhwFLga80WHhHXRMTIiBjZ1tbWQnPNzKwVrQSAdmBwMX0AsKTVPJJ2Ih38b4yI/2pkiIhlEfFKRKwHriUNNZmZ2RbSSgCYAQyVdJCk/sDJwJRKninAx/PTQIcBz0XEUkkCrgPmR8TXywKSBhaTJwJzN7sXZmbWY90+BRQR6ySNB6YB/YDJETFP0tl5/iRgKnAssBBYBZyei78XOBV4QNKcnHZRREwFLpM0nDRUtAg4q896ZWZm3Wrp56DzAXtqJW1S8T6Ac5qUu5vm9weIiFN71FIzM+tT/iawmVlNOQCYmdWUA4CZWU05AJiZ1ZT/J7BtE/z/cc22PF8BmJnVlAOAmVlNeQjIbBvj4TDbUnwFYGZWUw4AZmY15SEg26I8vGG27fAVgJlZTTkAmJnVlIeAzLYTHj6zvuYrADOzmnIAMDOrKQcAM7OacgAwM6sp3wS214RvWG5Z1c/bn7W1oqUrAEmjJS2QtFDShCbzJenKPP9+SSO6Kytpb0nTJT2c/w7omy6ZmVkrur0CkNQPuAo4GmgHZkiaEhEPFtnGAEPz61DgauDQbspOAG6PiEtyYJgA/EPfdc22BJ/pb9u8fqwrrQwBjQIWRsSjAJJuAsYCZQAYC9wQEQHcJ2kvSQOBIV2UHQsckctfD9yBA8A2yweS1xevTwNQOmZ3kUH6MDA6Ij6Zp08FDo2I8UWenwKXRMTdefp20sF8yKbKSno2IvYq6ngmIjYaBpI0DhiXJw8BFmxuZwv7Ak9tx+nbYpvct+7Tt8U2uW/dp2+rbeqJAyOibaPUiOjyBZwEfKuYPhX490qenwHvK6ZvB97dVVng2Uod
z3TXlr56ATO35/RtsU3um/u2rbXp9dy3vnq1chO4HRhcTB8ALGkxT1dll+VhIvLf5S20xczM+kgrAWAGMFTSQZL6AycDUyp5pgAfz08DHQY8FxFLuyk7BTgtvz8N+HEv+2JmZj3Q7U3giFgnaTwwDegHTI6IeZLOzvMnAVOBY4GFwCrg9K7K5qovAW6WdCbwOGm4aEu5ZjtP35rLdt82P31rLtt92/z0rbnsrtrUa93eBDYzs9cn/xSEmVlNOQCYmdXVa/mI0bb2AkaTvkewEJhQpE8mPYU0t0gbDPwCmA/MA87N6TsDvwF+l9O/VFlGP+C3wE8r6YuAB4A5FI92AXsBPwR+n5f1Z6TvO8wpXiuB83L+8/Ny5wLfA3bO6efmtGeA5yt92RuYDjwHrAUeLOadlMsE6Ut7jfTLgWeBdXn5e+X0LwMrgJfzct5c+Ryfz3Xtm9MuBl7M+VcDxxb57831vwRcltO+Dzyd868F5uT04XkdNeoZldPfBcwCXsjLnl+sqz/ObV2b50/I6WfnNgXwSJH/8rxtNKvr33L6S/nv5yvbybJc30U5/evAmpx/NTCpyP9QMe/2nD4lL/Ol3N72nP6Xeb016vlaTn9Pzr86//3XnD4w51+T23lpsZ4fzG18iGLbzW1dnV8rSd/pAfgX0j291bmuxrIb+8CSXF9j3X0lt71R141F/kW5Dy8Bv8rpP8jrYXUu92ROH5WX16jn2mJd30vaj54Dbq1s3w/n9k8r+jwPWE/av35arOffA/eTtvFbi237ftI+d1teTnU//rvc58YyLgaeyGVWAb8u8n6adLxZTd63SNt3Y79eQ3pgBtL2fV9Rz91N+vwTYI8+PSZu7YPylnqRDsyPAG8B+pMO4MPyvMOBEWx40BwIjMjvd887zTBAwG45fSfg18BhRbkLgO822XAWkQ+KlfTrgU/m9/3JB9pKu58EDgQGAX8A3pjn3Qx8gnSgmwvsAhyZ2/RQUcdlpJ/aOBy4Eugo5r0d+Cgwkw0DwDG5rhFAB50Hkj2Kz2sJ+cCW5/01cA9pZy4DwFVNPt8j8zIPzW3fr5jXqP8p4As57TbSzjcif5Z35PQZwIk5/Qzg0mJdXQVcmfN9gRQMhgHvz2XuyMtq5D+G9KjyiFzPxGLe0GJ7+DvSgWMYaTsZQ3rQ4XFSABkGXAFMbLL9/HVeP2/I6Y8U9TTqvzL3fVhu4/ji812V02eQvmQJ8H9IB6HD8rpufGYX5XV0WF7PhwB3ASMptt3c7z1zmcuLMnvQua2fRwpyh5H2gUNynx8jBeDD8rpuBMCy/iNJQfINOX1WUU+j/omkx8YPy+v6xJz+IdJB/bDc5w+Q9rH7yNs4ndv3BaQD6MJi2z4kr5Nb6QwAx5AegLmAFBQb+fcotsFbSAHlp0XaYFLgeJENA8CFVPb73Oefk7aV7wK3VfbrC0gnGAuK7XtMTr8DeLrYvj+Q358BfLkvj4t1GgJ69SctImIt0PhZCiLil6SDw6siYmlEzM7vG2eDgyJ5IWfbKb8CQNIBwAeBb7XSIEmNg+l1eTlrI+LZSrajgEci4rE8vSPwRkk7kg74S0gb+n0RsSoifgHcTdp5G8YC1+d+/mc5LyLmR8SNpDPxsv+35bpWkA46B+T0lcXntUOj79kppLOoqj9Q+XyBT5F22mW53le/B1LUvyfpKoe8nMdzej86v09yCPDfeV1NJx0w5pOC5V8A/5rzXUtaV4Mi4q6IuCWnr6Jz3d4WEe25rvuA/Yp5Dze2B9I6eCanLwXOBP6edKb5cF72C6SD8gbbD/AR4HMRsSanz23UExGzJYkUnObk/GtIARjSCcKKnN44AEM6YOydP6Oxua+QtvG9UxNifkQsAF7J817ddnO/n8vpM0ln7JHXdWNb34O8viMdjb6a+9z4PBrbwdpq/aR1/dWIWJPT1KgnIl7IfT6JdFYf+dU/17M3aduM3OdHSPvYRFLQJPf5tpz+FWD//LnPJx2s9wH+O+cl
Im7LeT6YP6Odc/pKeHU//hPSOitNym18vpK+Jxvv95/K06Pz38bnUh4n9iFvI7l/Q3L63aQrJXKff5nfTyedBPSdvowm2/IL+DAbfyv5m8X0EIoz1ErZIaSDzx55uh9pB331Ejun/5D0Degj2PgK4A/AbNLZz7icNpx0Kf1t0rDRt4BdK+Umk88A8/S5ebkddF5iv510hrkPKSjMJp9B5PnPVvrySpM+3kdxBVDJvxL4WJH2VdIB+CWgLacdTxomGcLGVwCLSDvTM8CAnD4H+FLu94vAeyrL/RtgdTH99rwOlpCGgQ7M6fcAY6PzrOqFxrpq0u/1bHiWdwcpYDxO5dKadLl9XmW9fxVYTLqsb8/LOB74tzy/vUhv9Pt+0pXa4pze6Pev82f+ZKVNh+cyjT40+r04530ip99DZ6BYA6xprGs23D5fqvTrjtz+6rbbKLMO+EmR/i95fb5C59XU8aSrlDn5M72ysq5X53X0jWJd/3Nez6+QTkbK5T6c0y+trOu1uf6ri3X9K9I+dhWwruhzue+treyTM4Gz2PBsvpH/HuC3lW17FemqYSydZ/TH53a+O6+H8grgxTxvapE+h3R10Rj2vauy7E/m9fDTos+rct0ddA4NVrfv5/vyuFinKwA1SYsmaRsWknYDfkQag18JEBGvRMRw0lnxKEl/LOk4YHlEzNpEVe+NiBGky7xzJB1OOnMaQdrA/5S0IZU/md2ftOH9IE8PIG2UBwFvBnaV9LFIZzqXks4QbiUdbLvtW4vOyX9vbCRExOeAPyfteOMl7QJ8jjTMUnU1cDDpeyIvA1/L6TsCA0gHsSdJ3wkp19Hxuf6GT5Huf/w5sJR81US6LD5H0ixSANyZYl3BButwVZlOOvhc1iT/50if36lsuN4/R9pRdyMdiNY1+p2X0UY6u19Z9Pu9pCGnh3J6o99Hkc4cxYZnlKfm+Y3lNvrdOEA8ldPPyPNeIY3h7yDpj3M7y+2zXyO98FGKbbdRhrSd3QrsVqRfFBH9SWfWH5L0ztznz+dltAMjcv5Gn3clnQx8JKfvSLrXtRvpl4H/prLcn5O2nUZ7PgWcn5d7RlHPDcAfkZ6N34UUHMj1b7TvNfZJUrBrlj46f35PFLPvJZ2QXUfaNnExRAkAAAUzSURBVMnb9+XAnU3274W5XYeQ7l29PacPIG0f7yRdObw7f1G2sex3k34yp+FS4JcRsT8puL0zp5fb9+4UVxJ9oi+jybb8It1cnVZM/yPwj8X0ECpXAKTL1WnABV3U+0XSGOC/knaGRaQD2irgO5soc3Eusz+wqEh/P/CzYnosxdgh6TL5umL648D/bVL/VcCSYnoBMDC/fw/5bLFSZqMrANI3tGcD85rkH5LrnUu6XF6e+95O53DN/s3y5+lbSWdrQ3Idj9B5NbEj6Szo90X550gHy0b+lU3W1V3A45V+D87r8Ivk8dYi/wrg6036fB8pmF7QZBnTSMNc1X6vJh1MXu13NX/R76Ma21Wl3zuTzua/VOn3TkX+lU3WxVvz535hZV0PJN1LuLDIewcwstx2i37fSzqwfrEsk+cfSBqu+3zR50Wkg9yzbPwwxJCc/8LGui7mrQC+WKzrZaSA1NiXnqPzO0oiXWlW97GO/Hl/h3TgXZLTl5MCw3eK/C/lz2FVkb4ip2+wrxZlFpNOWFaRTh5ezH1tDEetJ53JV9vUWPYjeXpRXsZ6UoBt5H8lt7XRppfY8PgRVI4feT3/pk+Pi31Z2bb8yhvao6Sz58ZN4HdUNtjyJqVIkf0blXra6Hwi5o2kg85xlTxHsOHl5q7A7sX7e+i8gXcXcEh+fzFweVHuJuD0YvpQ0lMNu+T2XQ98Os/bL//9o7zxlU/6XE7nEzCXUtwELvJsEABIZ0cPsvHN26HF57UE+GGTHb8cAhpYpC8FbsrTZ5OGBYaQhq8W07nTj87tKZc7n86A8QdgVqPfxbp6EDij0u/ZwDdIV1aNp1Ua+dvJB8NKn7/fZL0P
bWwPpKc7fljdTihu9JMOvo308yv9/l1Of2uj3/n1c/LTP5V+T8v5jyr6/XbSWfUOpKuzh4DjgG8CF+c8/5TrP67Ydu8m3QR+ddsl3Zf4fZ5fpo+ic1u/gHRAO44N94HHSIHjOOAdRfrfF/kvpHN4509IQa5Rz18Dd1aW+1DR5jGkK6Tj6NzGd8ifye+abN/XsPGJzB0UQ0DFem6j2FfJ23Z+/+ncrg2GcvO8cghoYJH+TeCJcvvO7z9GOkEot+87K8ueTw6S+bN+trJf70Dans6otqdXx8W+rGxbf5GGIR4iHSA/V6R/j3Rwepl0UDgTeB8pCjceC5uTy7+TNG59P+ks8AtNlvPqis3TbyHt9I1HR8tlDyeNUd5PulHVGCPfhXRms2el7i+Rdta5pBu6b8jpd+WN+hk6H6Ns9GUf0uXm86Sdr5x3IuksJPLrpZy+kM7HN4N0GX0m6Wzo2Zy+nhQEzqx8jtFIz23cKD8pCC8q6u8o6nmUdBZYXR8rinqW5fRzSWfdQTqjKtfVsTl9Te77AzntIjrP4l4m3eNo/JRJ43HO1aSzxkZddxTpz+X1eCwbbidri2XcWsk/N6cfUXzOjbHmsp72Sh8+VdTzIukM/1hSIH+JzpvEjSd/3l+s5xfofDz0xNy3Rp+fL8o0xtsbj13OyOnT2fDx0MtzerkPrKXzsdGfVPI3Au4IOs+4VwP/UdTzdO7zq/sS6Wqk8fjpKjrvAZxL2n8fYsMnbhrb98Oke2zTij63589iBWmYiPyZL86f8ULgsZz+o9yO+3NfPkz3AeA/8zq/nzQsOD2n9yed2c/N7b23KP9tUoA4oujD+3Lbf0faj+9q0udLyEGkr17+KQgzs5qq001gMzMrOACYmdWUA4CZWU05AJiZ1ZQDgJlZTTkAmJnVlAOAmVlN/X+5CGfSg+Fi7QAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "objects = [str(_) for _ in range(NBUCKETS)]\n",
    "y_pos = np.arange(len(objects))\n",
    "predictions = list(preds)\n",
    "\n",
    "plt.bar(y_pos, predictions, align='center', alpha=0.5)\n",
    "plt.xticks(y_pos, objects)\n",
    "plt.title('Baby weight prediction')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Restricting the prediction range\n",
    "\n",
    "One way to restrict the prediction range is to make the last-but-one activation function sigmoid instead, and add a lambda layer to scale the (0,1) values to the desired range. The drawback is that it will be difficult for the neural network to reach the extreme values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 2048 samples\n",
      "2048/2048 [==============================] - 1s 259us/sample - loss: 12.7289\n",
      "Train on 2048 samples\n",
      "2048/2048 [==============================] - 0s 131us/sample - loss: 9.4002\n",
      "Train on 2048 samples\n",
      "2048/2048 [==============================] - 0s 143us/sample - loss: 6.7786\n",
      "Train on 2048 samples\n",
      "2048/2048 [==============================] - 0s 117us/sample - loss: 4.5199\n",
      "Train on 2048 samples\n",
      "2048/2048 [==============================] - 0s 145us/sample - loss: 3.1557\n",
      "Train on 2048 samples\n",
      "2048/2048 [==============================] - 0s 143us/sample - loss: 2.2014\n",
      "Train on 2048 samples\n",
      "2048/2048 [==============================] - 0s 116us/sample - loss: 1.5578\n",
      "Train on 2048 samples\n",
      "2048/2048 [==============================] - 0s 125us/sample - loss: 1.1570\n",
      "Train on 2048 samples\n",
      "2048/2048 [==============================] - 0s 118us/sample - loss: 0.8444\n",
      "Train on 2048 samples\n",
      "2048/2048 [==============================] - 0s 175us/sample - loss: 0.6425\n",
      "min=3.029171943664551 max=19.990720748901367\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "\n",
    "MIN_Y =  3\n",
    "MAX_Y = 20\n",
    "input_size = 10\n",
    "inputs = keras.layers.Input(shape=(input_size,))\n",
    "h1 = keras.layers.Dense(20, 'relu')(inputs)\n",
    "h2 = keras.layers.Dense(1, 'sigmoid')(h1)  # 0-1 range\n",
    "output = keras.layers.Lambda(lambda y : (y*(MAX_Y-MIN_Y) + MIN_Y))(h2) # scaled\n",
    "model = keras.Model(inputs, output)\n",
    "\n",
    "# fit the model\n",
    "model.compile(optimizer='adam', loss='mse')\n",
    "batch_size = 2048\n",
    "for i in range(0, 10):\n",
    "    x = np.random.rand(batch_size, input_size)\n",
    "    y = 0.5*(x[:,0] + x[:,1]) * (MAX_Y-MIN_Y) + MIN_Y\n",
    "    model.fit(x, y)\n",
    "\n",
    "# verify\n",
    "min_y = np.finfo(np.float64).max\n",
    "max_y = np.finfo(np.float64).min\n",
    "for i in range(0, 10):\n",
    "    x = np.random.randn(batch_size, input_size)\n",
    "    y = model.predict(x)\n",
    "    min_y = min(y.min(), min_y)\n",
    "    max_y = max(y.max(), max_y)\n",
    "print('min={} max={}'.format(min_y, max_y))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright 2020 Google Inc. Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
